org.apache.spark.unsafe.types.UTF8String Scala Examples

The following examples show how to use org.apache.spark.unsafe.types.UTF8String. Each example is taken from an open-source project; the source file and project are noted in the line above the code.
Example 1
Source File: BooleanStatement.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake.pushdowns.querygeneration

import net.snowflake.spark.snowflake.{ConstantString, SnowflakeSQLStatement}
import org.apache.spark.sql.catalyst.expressions.{
  Attribute,
  Contains,
  EndsWith,
  EqualTo,
  Expression,
  GreaterThan,
  GreaterThanOrEqual,
  In,
  IsNotNull,
  IsNull,
  LessThan,
  LessThanOrEqual,
  Literal,
  Not,
  StartsWith
}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String


private[querygeneration] object BooleanStatement {
  def unapply(
    expAttr: (Expression, Seq[Attribute])
  ): Option[SnowflakeSQLStatement] = {
    val expr = expAttr._1
    val fields = expAttr._2

    Option(expr match {
      case In(child, list) if list.forall(_.isInstanceOf[Literal]) =>
        convertStatement(child, fields) + "IN" +
          blockStatement(convertStatements(fields, list: _*))
      case IsNull(child) =>
        blockStatement(convertStatement(child, fields) + "IS NULL")
      case IsNotNull(child) =>
        blockStatement(convertStatement(child, fields) + "IS NOT NULL")
      case Not(child) => {
        child match {
          case EqualTo(left, right) =>
            blockStatement(
              convertStatement(left, fields) + "!=" +
                convertStatement(right, fields)
            )
          case GreaterThanOrEqual(left, right) =>
            convertStatement(LessThan(left, right), fields)
          case LessThanOrEqual(left, right) =>
            convertStatement(GreaterThan(left, right), fields)
          case GreaterThan(left, right) =>
            convertStatement(LessThanOrEqual(left, right), fields)
          case LessThan(left, right) =>
            convertStatement(GreaterThanOrEqual(left, right), fields)
          case _ =>
            ConstantString("NOT") +
              blockStatement(convertStatement(child, fields))
        }
      }
      case Contains(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'%${pattern.toString}%'"
      case EndsWith(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'%${pattern.toString}'"
      case StartsWith(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'${pattern.toString}%'"

      case _ => null
    })
  }
} 
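
A minimal standalone sketch of the mechanism the pushdown relies on, using Catalyst's internal expression API directly (the column name "name" is invented for illustration): a string literal in an expression tree carries its value as a UTF8String, which is why the code above matches Literal(pattern: UTF8String, StringType) and calls pattern.toString when building the LIKE clause.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, StartsWith}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

val col = AttributeReference("name", StringType)()
val expr = StartsWith(col, Literal.create("abc", StringType))
// The literal's value is stored as a UTF8String, not a java.lang.String.
val Literal(pattern: UTF8String, StringType) = expr.right
assert(s"'${pattern.toString}%'" == "'abc%'")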
Example 2
Source File: ColumnarTestUtils.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.columnar

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String

object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericInternalRow = {
    val row = new GenericInternalRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }

  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case NULL => null
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case LONG => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
      case STRUCT(_) =>
        new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
      case ARRAY(_) =>
        new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
      case MAP(_) =>
        ArrayBasedMapData(
          Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
      case _ => throw new IllegalArgumentException(s"Unknown column type $columnType")
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }

  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {

    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericInternalRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }

  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = {

    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericInternalRow(1)
      row(0) = value
      row
    }

    (values, rows)
  }
} 
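
As a quick aside on the STRING case above: UTF8String tracks both character and byte counts, and multi-byte characters make the two differ. A tiny sketch:

import org.apache.spark.unsafe.types.UTF8String

// "héllo" is 5 characters but 6 UTF-8 bytes, since 'é' encodes as two bytes.
val u = UTF8String.fromString("héllo")
assert(u.numChars() == 5)
assert(u.numBytes() == 6)
assert(u.toString == "héllo")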
Example 3
Source File: RowSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
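
The test compares an internal row against an external Row; the key difference is that internal rows hold strings as UTF8String while Row exposes java.lang.String. A minimal sketch:

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.unsafe.types.UTF8String

val internal = new GenericInternalRow(1)
internal.update(0, UTF8String.fromString("this is a string"))
assert(internal.getUTF8String(0).toString == "this is a string")
// getString is a convenience that reads the UTF8String and converts it.
assert(internal.getString(0) == "this is a string")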
Example 4
Source File: FailureSafeParser.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

class FailureSafeParser[IN](
    rawParser: IN => Seq[InternalRow],
    mode: ParseMode,
    schema: StructType,
    columnNameOfCorruptRecord: String) {

  private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
  private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
  private val resultRow = new GenericInternalRow(schema.length)
  private val nullResult = new GenericInternalRow(schema.length)

  // This function takes 2 parameters: an optional partial result, and the bad record. If the given
  // schema doesn't contain a field for corrupted record, we just return the partial result or a
  // row with all fields null. If the given schema contains a field for corrupted record, we will
  // set the bad record to this field, and set other fields according to the partial result or null.
  private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = {
    if (corruptFieldIndex.isDefined) {
      (row, badRecord) => {
        var i = 0
        while (i < actualSchema.length) {
          val from = actualSchema(i)
          resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull
          i += 1
        }
        resultRow(corruptFieldIndex.get) = badRecord()
        resultRow
      }
    } else {
      (row, _) => row.getOrElse(nullResult)
    }
  }

  def parse(input: IN): Iterator[InternalRow] = {
    try {
      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
    } catch {
      case e: BadRecordException => mode match {
        case PermissiveMode =>
          Iterator(toResultRow(e.partialResult(), e.record))
        case DropMalformedMode =>
          Iterator.empty
        case FailFastMode =>
          throw new SparkException("Malformed records are detected in record parsing. " +
            s"Parse Mode: ${FailFastMode.name}.", e.cause)
      }
    }
  }
} 
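
A hypothetical end-to-end use of the class above; the rawParser, schema, and the "oops" input are invented for illustration. The bad record is handed back as a lazily built UTF8String, so it is only materialized when the corrupt-record column actually gets populated.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{BadRecordException, PermissiveMode}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

val schema = StructType(Seq(
  StructField("value", IntegerType),
  StructField("_corrupt_record", StringType)))

// A raw parser that throws BadRecordException for non-numeric input,
// handing the offending record back as a lazily built UTF8String.
def rawParser(s: String): Seq[InternalRow] =
  try Seq(InternalRow(s.trim.toInt)) catch {
    case e: NumberFormatException =>
      throw BadRecordException(
        record = () => UTF8String.fromString(s),
        partialResult = () => None,
        cause = e)
  }

val parser = new FailureSafeParser[String](rawParser, PermissiveMode, schema, "_corrupt_record")
// Under PERMISSIVE mode the malformed input lands in the corrupt-record column.
parser.parse("oops").foreach(row => println(row.getUTF8String(1)))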
Example 5
Source File: ComplexDataSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class ComplexDataSuite extends SparkFunSuite {
  def utf8(str: String): UTF8String = UTF8String.fromString(str)

  test("inequality tests for MapData") {
    // test data
    val testMap1 = Map(utf8("key1") -> 1)
    val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val testMap3 = Map(utf8("key1") -> 1)
    val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericRow = genericRow.copy()
    assert(copiedGenericRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedGenericRow.getString(0) == "a")
  }

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val mutableRow = new SpecificInternalRow(Seq(StringType))
    mutableRow(0) = unsafeRow.getUTF8String(0)
    val copiedMutableRow = mutableRow.copy()
    assert(copiedMutableRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedMutableRow.getString(0) == "a")
  }

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericArray = genericArray.copy()
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied array data should not be changed externally.
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
  }

  test("copy on nested complex type") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
    val copied = arrayOfRow.copy()
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied data should not be changed externally.
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
  }
} 
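
The copy tests above all guard against the same aliasing issue, which can also be seen on a single UTF8String: a value read from an UnsafeRow points into the projection's reused buffer, and clone() detaches it (a minimal sketch).

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val proj = UnsafeProjection.create(Array[DataType](StringType))
val aliased = proj(InternalRow(UTF8String.fromString("a"))).getUTF8String(0)
val detached = aliased.clone()                  // copies the bytes out of the projection's buffer
proj(InternalRow(UTF8String.fromString("b")))   // the projection reuses that buffer
assert(detached.toString == "a")                // the clone is unaffected by the second apply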
Example 6
Source File: NumberConverterSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }

} 
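
A standalone sketch of the same helper outside a test suite:

import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

// conv-style base conversion over raw UTF-8 bytes: 255 in base 10 is FF in base 16.
val hex = convert(UTF8String.fromString("255").getBytes, 10, 16)
assert(hex == UTF8String.fromString("FF"))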
Example 7
Source File: GenerateUnsafeProjectionSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class GenerateUnsafeProjectionSuite extends SparkFunSuite {
  test("Test unsafe projection string access pattern") {
    val dataType = (new StructType).add("a", StringType)
    val exprs = BoundReference(0, dataType, nullable = true) :: Nil
    val projection = GenerateUnsafeProjection.generate(exprs)
    val result = projection.apply(InternalRow(AlwaysNull))
    assert(!result.isNullAt(0))
    assert(result.getStruct(0, 1).isNullAt(0))
  }
}

object AlwaysNull extends InternalRow {
  override def numFields: Int = 1
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this
  override def anyNull: Boolean = true
  override def isNullAt(ordinal: Int): Boolean = true
  override def update(i: Int, value: Any): Unit = notSupported
  override def getBoolean(ordinal: Int): Boolean = notSupported
  override def getByte(ordinal: Int): Byte = notSupported
  override def getShort(ordinal: Int): Short = notSupported
  override def getInt(ordinal: Int): Int = notSupported
  override def getLong(ordinal: Int): Long = notSupported
  override def getFloat(ordinal: Int): Float = notSupported
  override def getDouble(ordinal: Int): Double = notSupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
  override def getUTF8String(ordinal: Int): UTF8String = notSupported
  override def getBinary(ordinal: Int): Array[Byte] = notSupported
  override def getInterval(ordinal: Int): CalendarInterval = notSupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
  override def getArray(ordinal: Int): ArrayData = notSupported
  override def getMap(ordinal: Int): MapData = notSupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
  private def notSupported: Nothing = throw new UnsupportedOperationException
} 
Example 8
Source File: MiscExpressionsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("assert_true") {
    intercept[RuntimeException] {
      checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null)
    }
    intercept[RuntimeException] {
      checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null)
    }
    intercept[RuntimeException] {
      checkEvaluation(AssertTrue(Literal.create(null, NullType)), null)
    }
    intercept[RuntimeException] {
      checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null)
    }
    checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null)
    checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null)
  }

  test("uuid") {
    def assertIncorrectEval(f: () => Unit): Unit = {
      intercept[Exception] {
        f()
      }.getMessage().contains("Incorrect evaluation")
    }

    checkEvaluation(Length(Uuid(Some(0))), 36)
    val r = new Random()
    val seed1 = Some(r.nextLong())
    val uuid1 = evaluate(Uuid(seed1)).asInstanceOf[UTF8String]
    checkEvaluation(Uuid(seed1), uuid1.toString)

    val seed2 = Some(r.nextLong())
    val uuid2 = evaluate(Uuid(seed2)).asInstanceOf[UTF8String]
    assertIncorrectEval(() => checkEvaluationWithoutCodegen(Uuid(seed1), uuid2))
    assertIncorrectEval(() => checkEvaluationWithGeneratedMutableProjection(Uuid(seed1), uuid2))
    assertIncorrectEval(() => checkEvalutionWithUnsafeProjection(Uuid(seed1), uuid2))
    assertIncorrectEval(() => checkEvaluationWithOptimization(Uuid(seed1), uuid2))
  }
} 
Example 9
Source File: StringUtils.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import java.util.regex.{Pattern, PatternSyntaxException}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  
  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() }
      } catch {
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }
} 
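
A small usage sketch (the function names are illustrative): '*' matches any sequence of characters, '|' separates alternative patterns, matching is case-insensitive, and the result comes back sorted.

val names = Seq("size", "substr", "substring", "to_date")
assert(StringUtils.filterPattern(names, "SUB*") == Seq("substr", "substring"))
assert(StringUtils.filterPattern(names, "size|to_date") == Seq("size", "to_date"))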
Example 10
Source File: CreateJacksonParser.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.json

import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text

import org.apache.spark.unsafe.types.UTF8String

private[sql] object CreateJacksonParser extends Serializable {
  def string(jsonFactory: JsonFactory, record: String): JsonParser = {
    jsonFactory.createParser(record)
  }

  def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
    val bb = record.getByteBuffer
    assert(bb.hasArray)

    val bain = new ByteArrayInputStream(
      bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())

    jsonFactory.createParser(new InputStreamReader(bain, "UTF-8"))
  }

  def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
    jsonFactory.createParser(record.getBytes, 0, record.getLength)
  }

  def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
    jsonFactory.createParser(record)
  }
} 
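
A minimal sketch of the utf8String variant (note that CreateJacksonParser is private[sql], so this assumes code living under the org.apache.spark.sql namespace): the record's backing bytes are wrapped directly, avoiding a copy to java.lang.String.

import com.fasterxml.jackson.core.{JsonFactory, JsonToken}
import org.apache.spark.unsafe.types.UTF8String

val factory = new JsonFactory()
val record = UTF8String.fromString("""{"a": 1}""")
val parser = CreateJacksonParser.utf8String(factory, record)
assert(parser.nextToken() == JsonToken.START_OBJECT)
parser.close()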
Example 11
Source File: converters.scala    From scylla-migrator   with Apache License 2.0
package com.scylladb.migrator

import java.nio.charset.StandardCharsets
import java.util.UUID

import com.datastax.spark.connector.types.{
  CustomDriverConverter,
  NullableTypeConverter,
  PrimitiveColumnType,
  TypeConverter
}
import org.apache.spark.unsafe.types.UTF8String

import scala.reflect.runtime.universe.TypeTag

case object AnotherCustomUUIDConverter extends NullableTypeConverter[UUID] {
  def targetTypeTag = implicitly[TypeTag[UUID]]
  def convertPF = {
    case x: UUID   => x
    case x: String => UUID.fromString(x)
    case x: UTF8String =>
      UUID.fromString(new String(x.getBytes, StandardCharsets.UTF_8))
  }
}

case object CustomTimeUUIDType extends PrimitiveColumnType[UUID] {
  def scalaTypeTag = implicitly[TypeTag[UUID]]
  def cqlTypeName = "timeuuid"
  def converterToCassandra =
    new TypeConverter.OptionToNullConverter(AnotherCustomUUIDConverter)
}

case object CustomUUIDType extends PrimitiveColumnType[UUID] {
  def scalaTypeTag = implicitly[TypeTag[UUID]]
  def cqlTypeName = "uuid"
  def converterToCassandra =
    new TypeConverter.OptionToNullConverter(AnotherCustomUUIDConverter)
}

object CustomUUIDConverter extends CustomDriverConverter {
  import org.apache.spark.sql.{ types => catalystTypes }
  import com.datastax.driver.core.DataType
  import com.datastax.spark.connector.types.ColumnType

  override val fromDriverRowExtension: PartialFunction[DataType, ColumnType[_]] = {
    case dataType if dataType.getName == DataType.timeuuid().getName =>
      CustomTimeUUIDType
    case dataType if dataType.getName == DataType.uuid().getName =>
      CustomUUIDType
  }

  override val catalystDataType: PartialFunction[ColumnType[_], catalystTypes.DataType] = {
    case CustomTimeUUIDType => catalystTypes.StringType
    case CustomUUIDType     => catalystTypes.StringType
  }
} 
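
A sketch of the motivating case, assuming the connector's TypeConverter.convert entry point: UUID values arriving from Spark's internal rows show up as UTF8String rather than String, and the custom converter accepts both.

import java.util.UUID
import org.apache.spark.unsafe.types.UTF8String

val id = UUID.randomUUID()
assert(AnotherCustomUUIDConverter.convert(UTF8String.fromString(id.toString)) == id)
assert(AnotherCustomUUIDConverter.convert(id.toString) == id)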
Example 12
Source File: PrefixComparatorsSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.util.collection.unsafe.sort

import com.google.common.primitives.UnsignedBytes
import org.scalatest.prop.PropertyChecks
import org.apache.spark.SparkFunSuite
import org.apache.spark.unsafe.types.UTF8String

class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {

  test("String prefix comparator") {//字符串的前缀比较器

    def testPrefixComparison(s1: String, s2: String): Unit = {
      val utf8string1 = UTF8String.fromString(s1)
      val utf8string2 = UTF8String.fromString(s2)
      val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1)
      val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2)
      val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)

      val cmp = UnsignedBytes.lexicographicalComparator().compare(
        utf8string1.getBytes.take(8), utf8string2.getBytes.take(8))

      assert(
        (prefixComparisonResult == 0 && cmp == 0) ||
        (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) ||
        (prefixComparisonResult > 0 && s1.compareTo(s2) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
    forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
  }

  test("Binary prefix comparator") {//二进制前缀比较器

     def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
      for (i <- 0 until x.length; if i < y.length) {
        val res = x(i).compare(y(i))
        if (res != 0) return res
      }
      x.length - y.length
    }

    def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = {
      val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x)
      val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y)
      val prefixComparisonResult =
        PrefixComparators.BINARY.compare(s1Prefix, s2Prefix)
      assert(
        (prefixComparisonResult == 0) ||
        (prefixComparisonResult < 0 && compareBinary(x, y) < 0) ||
        (prefixComparisonResult > 0 && compareBinary(x, y) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
    forAll { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
  }
  test("double prefix comparator handles NaNs properly") {
    val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
    val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
    assert(nan1.isNaN)
    assert(nan2.isNaN)
    val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
    val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2)
    assert(nan1Prefix === nan2Prefix)
    val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
    assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
  }

} 
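
A standalone sketch of why the test must fall back to a full comparison when prefixes tie: the string prefix is derived from at most the first 8 UTF-8 bytes.

import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators

val p1 = PrefixComparators.StringPrefixComparator.computePrefix(UTF8String.fromString("spark-sql-aaa"))
val p2 = PrefixComparators.StringPrefixComparator.computePrefix(UTF8String.fromString("spark-sql-bbb"))
// The first 8 bytes ("spark-sq") are identical, so the prefixes compare equal.
assert(PrefixComparators.STRING.compare(p1, p2) == 0)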
Example 13
Source File: DDLTestSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext)
  }
}

case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType =
    StructType(Seq( // StructType describes the table schema; each StructField is one column
      StructField("intType", IntegerType, nullable = false,
        new MetadataBuilder().putString("comment", s"test comment $table").build()),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType",//StructType代表一张表,StructField代表一个字段
        StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil
        )
      )
    ))
  // Rows are produced as InternalRow already, so Spark should not convert them
  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    sqlContext.sparkContext.parallelize(from to to).map { e =>
      InternalRow(UTF8String.fromString(s"people$e"), e * 2)
    }.asInstanceOf[RDD[Row]]
  }
}

class DDLTestSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql(
      """
      |CREATE TEMPORARY TABLE ddlPeople
      |USING org.apache.spark.sql.sources.DDLScanSource
      |OPTIONS (
      |  From '1',
      |  To '10',
      |  Table 'test1'
      |)
      """.stripMargin)
  }

  sqlTest(
      "describe ddlPeople",
      Seq(
        Row("intType", "int", "test comment test1"),
        Row("stringType", "string", ""),
        Row("dateType", "date", ""),
        Row("timestampType", "timestamp", ""),
        Row("doubleType", "double", ""),
        Row("bigintType", "bigint", ""),
        Row("tinyintType", "tinyint", ""),
        Row("decimalType", "decimal(10,0)", ""),
        Row("fixedDecimalType", "decimal(5,1)", ""),
        Row("binaryType", "binary", ""),
        Row("booleanType", "boolean", ""),
        Row("smallIntType", "smallint", ""),
        Row("floatType", "float", ""),
        Row("mapType", "map<string,string>", ""),
        Row("arrayType", "array<string>", ""),
        Row("structType", "struct<f1:string,f2:int>", "")
      ))
  test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
    val attributes = sql("describe ddlPeople")
      .queryExecution.executedPlan.output
    assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
    assert(attributes.map(_.dataType).toSet === Set(StringType))
  }
} 
Example 14
Source File: ColumnarTestUtils.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.columnar

import scala.collection.immutable.HashSet
import scala.util.Random
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{DataType, Decimal, AtomicType}
import org.apache.spark.unsafe.types.UTF8String
// Utilities for columnar tests
object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericMutableRow = {
    val row = new GenericMutableRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }
  // Generate a random value for the given column type
  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case DATE => Random.nextInt()
      case LONG => Random.nextLong()
      case TIMESTAMP => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case _ =>
        // Using a random one-element map instead of an arbitrary object
        Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }
  // Generate distinct random values of the given column type
  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {

    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericMutableRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }
  // Generate distinct values and the corresponding single-value rows
  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = {

    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericMutableRow(1)
      row(0) = value
      row
    }

    (values, rows)
  }
} 
Example 15
Source File: ExtraStrategiesSuite.scala    From spark1.52   with Apache License 2.0
package test.org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String
// Fast operator
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    sparkContext.parallelize(Seq(row))
  }
  // Nil is the empty List
  override def children: Seq[SparkPlan] = Nil
}
// Test strategy
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      // :: prepends an element to a List, creating a new list
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}
// Extra strategies test suite
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {//插入一个额外的策略
    try {
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      sqlContext.experimental.extraStrategies = Nil
    }
  }
} 
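
Why FastOperator can drop Literal("so fast").value straight into a GenericInternalRow: Catalyst literals created from a Scala String already hold a UTF8String. A one-line sketch:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.unsafe.types.UTF8String

assert(Literal("so fast").value.isInstanceOf[UTF8String])
assert(Literal("so fast").value.toString == "so fast")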
Example 16
Source File: SQLImplicits.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql

import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types.StructField
import org.apache.spark.unsafe.types.UTF8String


private[sql] abstract class SQLImplicits {
  protected def _sqlContext: SQLContext

  // Builds a single-column ("_1") DataFrame of strings from an RDD[String].
  implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
    val dataType = StringType
    val rows = data.mapPartitions { iter =>
      val row = new SpecificMutableRow(dataType :: Nil)
      iter.map { v =>
        row.update(0, UTF8String.fromString(v))
        row: InternalRow
      }
    }
    DataFrameHolder(
      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
  }
} 
Example 17
Source File: NumberConverterSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String


class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }
  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }

} 
Example 18
Source File: GeneratedProjectionSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


class GeneratedProjectionSuite extends SparkFunSuite {
  test("generated projections on wider table") {
    val N = 1000
    val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
    val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
    val wideRow2 = new GenericInternalRow(
      (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
    val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
    val joined = new JoinedRow(wideRow1, wideRow2)
    val joinedSchema = StructType(schema1 ++ schema2)
    val nested = new JoinedRow(InternalRow(joined, joined), joined)
    val nestedSchema = StructType(
      Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)

    // test generated UnsafeProjection
    val unsafeProj = UnsafeProjection.create(nestedSchema)
    val unsafe: UnsafeRow = unsafeProj(nested)
    (0 until N).foreach { i =>
      val s = UTF8String.fromString((i + 1).toString)
      assert(i + 1 === unsafe.getInt(i + 2))
      assert(s === unsafe.getUTF8String(i + 2 + N))
      assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
      assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
      assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
      assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
    }

    // test generated SafeProjection
    val safeProj = FromUnsafeProjection(nestedSchema)
    val result = safeProj(unsafe)
    // Can't compare GenericInternalRow with JoinedRow directly
    (0 until N).foreach { i =>
      val r = i + 1
      val s = UTF8String.fromString((i + 1).toString)
      assert(r === result.getInt(i + 2))
      assert(s === result.getUTF8String(i + 2 + N))
      assert(r === result.getStruct(0, N * 2).getInt(i))
      assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
      assert(r === result.getStruct(1, N * 2).getInt(i))
      assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
    }

    // test generated MutableProjection
    val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
      BoundReference(i, f.dataType, true)
    }
    val mutableProj = GenerateMutableProjection.generate(exprs)()
    val row1 = mutableProj(result)
    assert(result === row1)
    val row2 = mutableProj(result)
    assert(result === row2)
  }

  test("generated unsafe projection with array of binary") {
    val row = InternalRow(
      Array[Byte](1, 2),
      new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4))))
    val fields = (BinaryType :: ArrayType(BinaryType) :: Nil).toArray[DataType]

    val unsafeProj = UnsafeProjection.create(fields)
    val unsafeRow: UnsafeRow = unsafeProj(row)
    assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2)))
    assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2)))
    assert(unsafeRow.getArray(1).isNullAt(1))
    assert(unsafeRow.getArray(1).getBinary(1) === null)
    assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4)))

    val safeProj = FromUnsafeProjection(fields)
    val row2 = safeProj(unsafeRow)
    assert(row2 === row)
  }
} 
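
A much smaller round trip than the wide-table test above, showing the same storage behaviour: an UnsafeRow keeps fixed-width fields inline and stores string fields as UTF8String in its variable-length region (a minimal sketch).

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val proj = UnsafeProjection.create(Array[DataType](IntegerType, StringType))
val unsafeRow = proj(InternalRow(7, UTF8String.fromString("seven")))
assert(unsafeRow.getInt(0) == 7)
assert(unsafeRow.getUTF8String(1).toString == "seven")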
Example 19
Source File: RowSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 20
Source File: ColumnarTestUtils.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.execution.columnar

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String

object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericInternalRow = {
    val row = new GenericInternalRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }

  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case NULL => null
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case LONG => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
      case STRUCT(_) =>
        new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
      case ARRAY(_) =>
        new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
      case MAP(_) =>
        ArrayBasedMapData(
          Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
      case _ => throw new IllegalArgumentException(s"Unknown column type $columnType")
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }

  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {

    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericInternalRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }

  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = {

    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericInternalRow(1)
      row(0) = value
      row
    }

    (values, rows)
  }
} 
Example 21
Source File: NumberConverterSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }

} 
Example 22
Source File: MapDataSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 23
Source File: StringUtils.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import java.util.regex.{Pattern, PatternSyntaxException}

import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  // replace the _ with .{1} exactly match 1 time of any character
  // replace the % with .*, match 0 or more times with any character
  def escapeLikeRegex(v: String): String = {
    if (!v.isEmpty) {
      "(?s)" + (' ' +: v.init).zip(v).flatMap {
        case (prev, '\\') => ""
        case ('\\', c) =>
          c match {
            case '_' => "_"
            case '%' => "%"
            case _ => Pattern.quote("\\" + c)
          }
        case (prev, c) =>
          c match {
            case '_' => "."
            case '%' => ".*"
            case _ => Pattern.quote(Character.toString(c))
          }
      }.mkString
    } else {
      v
    }
  }

  private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
  private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)

  def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
  def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)

  
  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() }
      } catch {
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }
} 
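
A short usage sketch of the helpers above: boolean-ish strings are matched case-insensitively against the whitelists, and escapeLikeRegex turns '_' into '.' and '%' into '.*'.

import org.apache.spark.unsafe.types.UTF8String

assert(StringUtils.isTrueString(UTF8String.fromString("YES")))
assert(StringUtils.isFalseString(UTF8String.fromString("0")))
assert("a_cd".matches(StringUtils.escapeLikeRegex("a_c%")))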
Example 24
Source File: similarityFunctions.scala    From spark-stringmetric   with MIT License
package com.github.mrpowers.spark.stringmetric.expressions

import com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions
import org.apache.commons.text.similarity.CosineDistance
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{
  CodegenContext,
  ExprCode
}
import org.apache.spark.sql.types.{ DataType, IntegerType, StringType }


trait UTF8StringFunctionsHelper {
  val stringFuncs: String = "com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions"
}

trait StringString2IntegerExpression
extends ImplicitCastInputTypes
with NullIntolerant
with UTF8StringFunctionsHelper { self: BinaryExpression =>
  override def dataType: DataType = IntegerType
  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

  protected override def nullSafeEval(left: Any, right: Any): Any = -1
}

case class HammingDistance(left: Expression, right: Expression)
extends BinaryExpression with StringString2IntegerExpression {
  override def prettyName: String = "hamming"

  override def nullSafeEval(leftVal: Any, rightVal: Any): Any = {
    // Cast the evaluated operands (not the child expressions) to UTF8String.
    val leftStr = leftVal.asInstanceOf[UTF8String]
    val rightStr = rightVal.asInstanceOf[UTF8String]
    UTF8StringFunctions.hammingDistance(leftStr, rightStr)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    defineCodeGen(ctx, ev, (s1, s2) => s"$stringFuncs.hammingDistance($s1, $s2)")
  }
} 
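
A small evaluation sketch against two string literals, assuming UTF8StringFunctions.hammingDistance implements the usual positional-difference count; the expected value of 3 for these inputs follows from that assumption.

import org.apache.spark.sql.catalyst.expressions.Literal

val distance = HammingDistance(Literal("toned"), Literal("roses")).eval()
println(distance) // 3 under the usual definition of Hamming distance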
Example 25
Source File: PgWireProtocolSuite.scala    From spark-sql-server   with Apache License 2.0
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.SQLException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

class PgWireProtocolSuite extends SparkFunSuite {

  val conf = new SQLConf()

  test("DataRow") {
    val v3Protocol = new PgWireProtocol(65536)
    val row = new GenericInternalRow(2)
    row.update(0, 8)
    row.update(1, UTF8String.fromString("abcdefghij"))
    val schema = StructType.fromDDL("a INT, b STRING")
    val rowConverters = PgRowConverters(conf, schema, Seq(true, false))
    val data = v3Protocol.DataRow(row, rowConverters)
    val bytes = ByteBuffer.wrap(data)
    assert(bytes.get() === 'D'.toByte)
    assert(bytes.getInt === 28)
    assert(bytes.getShort === 2)
    assert(bytes.getInt === 4)
    assert(bytes.getInt === 8)
    assert(bytes.getInt === 10)
    assert(data.slice(19, 30) === "abcdefghij".getBytes(StandardCharsets.UTF_8))
  }

  test("Fails when message buffer overflowed") {
    val v3Protocol = new PgWireProtocol(4)
    val row = new GenericInternalRow(1)
    row.update(0, UTF8String.fromString("abcdefghijk"))
    val schema = StructType.fromDDL("a STRING")
    val rowConverters = PgRowConverters(conf, schema, Seq(false))
    val errMsg = intercept[SQLException] {
      v3Protocol.DataRow(row, rowConverters)
    }.getMessage
    assert(errMsg.contains(
      "Cannot generate a V3 protocol message because buffer is not enough for the message. " +
        "To avoid this exception, you might set higher value at " +
        "'spark.sql.server.messageBufferSizeInBytes'")
    )
  }
} 
Example 26
Source File: TestUtils.scala    From spark-sql-server   with Apache License 2.0
package org.apache.spark.sql.server

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.UTF8String

object TestUtils {

  def toSparkIntervalString(pgIntervalString: String): String = {
    // See: org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit
    val sparkIntervalString = UTF8String.fromString(pgIntervalString
      .replace("years", "year")
      .replace("mons", "month")
      .replace("days", "day")
      .replace("hours", "hour")
      .replace("mins", "minute")
      .replace("secs", "second")
    )
    IntervalUtils.stringToInterval(sparkIntervalString).toString
  }
} 
Example 27
Source File: IntegerEncodedStringColumnBuffer.scala    From spark-vector   with Apache License 2.0
package com.actian.spark_vector.colbuffer.string

import com.actian.spark_vector.colbuffer._
import com.actian.spark_vector.colbuffer.util.StringConversion
import com.actian.spark_vector.vector.VectorDataType

import org.apache.spark.unsafe.types.UTF8String

import java.nio.ByteBuffer

private[colbuffer] abstract class IntegerEncodedStringColumnBuffer(p: ColumnBufferBuildParams)
    extends ColumnBuffer[String, UTF8String](p.name, p.maxValueCount, IntSize, IntSize, p.nullable) {
  protected final val Whitespace = '\u0020'

  override def put(source: String, buffer: ByteBuffer): Unit = if (source.isEmpty()) {
    buffer.putInt(Whitespace)
  } else {
    buffer.putInt(encode(source))
  }

  protected def encode(value: String): Int

  override def get(buffer: ByteBuffer): UTF8String = UTF8String.fromBytes(Character.toChars(buffer.getInt()).map(_.toByte))
}

private class ConstantLengthSingleByteStringColumnBuffer(p: ColumnBufferBuildParams) extends IntegerEncodedStringColumnBuffer(p) {
  override protected def encode(value: String): Int = if (StringConversion.truncateToUTF8Bytes(value, 1).length == 0) {
    Whitespace
  } else {
    value.codePointAt(0)
  }
}

private class ConstantLengthSingleCharStringColumnBuffer(p: ColumnBufferBuildParams) extends IntegerEncodedStringColumnBuffer(p) {
  override protected def encode(value: String): Int = if (Character.isHighSurrogate(value.charAt(0))) {
    Whitespace
  } else {
    value.codePointAt(0)
  }
}


private[colbuffer] object IntegerEncodedStringColumnBuffer extends ColumnBufferBuilder {
  private val buildPartial: PartialFunction[ColumnBufferBuildParams, ColumnBufferBuildParams] = {
    case p if p.precision == 1 => p
  }

  override private[colbuffer] val build: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] = buildPartial andThenPartial {
    (ofDataType(VectorDataType.CharType) andThen { new ConstantLengthSingleByteStringColumnBuffer(_) }) orElse
      (ofDataType(VectorDataType.NcharType) andThen { new ConstantLengthSingleCharStringColumnBuffer(_) })
  }
} 
Example 28
Source File: ByteEncodedStringColumnBuffer.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector.colbuffer.string

import com.actian.spark_vector.colbuffer._
import com.actian.spark_vector.colbuffer.util.StringConversion
import com.actian.spark_vector.vector.VectorDataType

import org.apache.spark.unsafe.types.UTF8String

import java.nio.ByteBuffer

private[colbuffer] abstract class ByteEncodedStringColumnBuffer(p: ColumnBufferBuildParams)
    extends ColumnBuffer[String, UTF8String](p.name, p.maxValueCount, p.precision + 1, ByteSize, p.nullable) {
  override def put(source: String, buffer: ByteBuffer): Unit = {
    if (source.exists(c => (c >= '\uD800') && (c <= '\uDFFF') || (c == '\u0000'))) {
      throw new Exception(s"Illegal character in column '${p.name}' in string '$source'")
    }
    buffer.put(encode(source))
    buffer.put(0.toByte)
  }

  protected def encode(value: String): Array[Byte]

  override def get(buffer: ByteBuffer): UTF8String = {
    // NOTE: the original body of `get` (and the ByteLengthLimitedStringColumnBuffer /
    // CharLengthLimitedStringColumnBuffer classes referenced by the builder below) is
    // elided in this excerpt. A minimal reconstruction reads back the zero-terminated
    // bytes written by `put` above.
    val bytes = scala.collection.mutable.ArrayBuffer.empty[Byte]
    var b = buffer.get()
    while (b != 0) {
      bytes += b
      b = buffer.get()
    }
    UTF8String.fromBytes(bytes.toArray)
  }
}

private[colbuffer] object ByteEncodedStringColumnBuffer extends ColumnBufferBuilder {
  private val buildConstLenMultiPartial: PartialFunction[ColumnBufferBuildParams, ColumnBufferBuildParams] = {
    case p if p.precision > 1 => p
  }

  private val buildConstLenMulti: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] = buildConstLenMultiPartial andThenPartial {
    (ofDataType(VectorDataType.CharType) andThen { new ByteLengthLimitedStringColumnBuffer(_) }) orElse
      (ofDataType(VectorDataType.NcharType) andThen { new CharLengthLimitedStringColumnBuffer(_) })
  }

  private val buildVarLenPartial: PartialFunction[ColumnBufferBuildParams, ColumnBufferBuildParams] = {
    case p if p.precision > 0 => p
  }

  private val buildVarLen: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] = buildVarLenPartial andThenPartial {
    (ofDataType(VectorDataType.VarcharType) andThen { new ByteLengthLimitedStringColumnBuffer(_) }) orElse
      (ofDataType(VectorDataType.NvarcharType) andThen { new CharLengthLimitedStringColumnBuffer(_) })
  }

  override private[colbuffer] val build: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] = buildConstLenMulti orElse buildVarLen
} 
Example 29
Source File: MeanSubstitute.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution


case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  def this(array: Expression) = {
    this(array, Literal(-1))
  }

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)
  }

  // Update sum and count with array element if not missing
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    If(
      isMissing(arrayElement),
      // If value is missing, do not update sum and count
      stateStruct,
      // If value is not missing, add to sum and increment count
      createNamedStruct(
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)
    )
  }

  // Calculate mean for imputation
  def calculateMean(stateStruct: Expression): Expression = {
    If(
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value
      missingValue
    )
  }

  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
    array.aggregate(
      createNamedStruct(Literal(0d), Literal(0L)),
      updateSumAndCountConditionally,
      calculateMean
    )
  }

  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)
  }

  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }

    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    // Replace missing values with the provided strategy
    array.arrayTransform(substituteWithMean(_))
  }
} 
Example 30
Source File: PlinkRowToInternalRowConverter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.plink

import org.apache.spark.sql.SQLUtils.structFieldsEqualExceptNullability
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.RowConverter


class PlinkRowToInternalRowConverter(schema: StructType) extends GlowLogging {

  private val homAlt = new GenericArrayData(Array(1, 1))
  private val missing = new GenericArrayData(Array(-1, -1))
  private val het = new GenericArrayData(Array(0, 1))
  private val homRef = new GenericArrayData(Array(0, 0))

  private def twoBitsToCalls(twoBits: Int): GenericArrayData = {
    twoBits match {
      case 0 => homAlt // Homozygous for first (alternate) allele
      case 1 => missing // Missing genotype
      case 2 => het // Heterozygous
      case 3 => homRef // Homozygous for second (reference) allele
    }
  }

  private val converter = {
    val fns = schema.map { field =>
      val fn: RowConverter.Updater[(Array[UTF8String], Array[Byte])] = field match {
        case f if f.name == VariantSchemas.genotypesFieldName =>
          val gSchema = f.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
          val converter = makeGenotypeConverter(gSchema)
          (samplesAndBlock, r, i) => {
            val genotypes = new Array[Any](samplesAndBlock._1.length)
            var sampleIdx = 0
            while (sampleIdx < genotypes.length) {
              val sample = samplesAndBlock._1(sampleIdx)
              // Get the relevant 2 bits for the sample from the genotype block:
              // the i-th sample's call bits are the (i%4)-th pair within the (i/4)-th byte
              val twoBits = samplesAndBlock._2(sampleIdx / 4) >> (2 * (sampleIdx % 4)) & 3
              genotypes(sampleIdx) = converter((sample, twoBits))
              sampleIdx += 1
            }
            r.update(i, new GenericArrayData(genotypes))
          }
        case _ =>
          // BED file only contains genotypes
          (_, _, _) => ()
      }
      fn
    }
    new RowConverter[(Array[UTF8String], Array[Byte])](schema, fns.toArray)
  }

  private def makeGenotypeConverter(gSchema: StructType): RowConverter[(UTF8String, Int)] = {
    val functions = gSchema.map { field =>
      val fn: RowConverter.Updater[(UTF8String, Int)] = field match {
        case f if structFieldsEqualExceptNullability(f, VariantSchemas.sampleIdField) =>
          (sampleAndTwoBits, r, i) => {
            r.update(i, sampleAndTwoBits._1)
          }
        case f if structFieldsEqualExceptNullability(f, VariantSchemas.callsField) =>
          (sampleAndTwoBits, r, i) => r.update(i, twoBitsToCalls(sampleAndTwoBits._2))
        case f =>
          logger.info(
            s"Genotype field $f cannot be derived from PLINK files. It will be null " +
            s"for each sample."
          )
          (_, _, _) => ()
      }
      fn
    }
    new RowConverter[(UTF8String, Int)](gSchema, functions.toArray)
  }

  def convertRow(
      bimRow: InternalRow,
      sampleIds: Array[UTF8String],
      gtBlock: Array[Byte]): InternalRow = {
    converter((sampleIds, gtBlock), bimRow)
  }
} 
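The bit extraction in the genotype updater is worth isolating. A hedged stand-alone sketch of the same arithmetic (the helper name is hypothetical):

// Pull the 2-bit call code for sample i out of a packed PLINK genotype block:
// each byte of the block holds four samples, two bits per sample.
def callCode(block: Array[Byte], i: Int): Int =
  (block(i / 4) >> (2 * (i % 4))) & 3

callCode(Array(0xE4.toByte), 2) // == 2, i.e. heterozygous, for the byte 0b11100100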
Example 31
Source File: UTF8TextOutputFormatter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.transformers.pipe

import java.io.InputStream

import scala.collection.JavaConverters._

import org.apache.commons.io.IOUtils
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String


class UTF8TextOutputFormatter() extends OutputFormatter {

  override def makeIterator(stream: InputStream): Iterator[Any] = {
    val schema = StructType(Seq(StructField("text", StringType)))
    val iter = IOUtils.lineIterator(stream, "UTF-8").asScala.map { s =>
      new GenericInternalRow(Array(UTF8String.fromString(s)): Array[Any])
    }
    Iterator(schema) ++ iter
  }
}

class UTF8TextOutputFormatterFactory extends OutputFormatterFactory {
  override def name: String = "text"

  override def makeOutputFormatter(options: Map[String, String]): OutputFormatter = {
    new UTF8TextOutputFormatter
  }
} 
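A hedged usage sketch for the formatter above: the iterator yields the schema first, then one GenericInternalRow per input line (the stream content here is made up):

import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets

val stream = new ByteArrayInputStream("foo\nbar\n".getBytes(StandardCharsets.UTF_8))
val iter = new UTF8TextOutputFormatter().makeIterator(stream)
val schema = iter.next() // StructType(Seq(StructField("text", StringType)))
val rows = iter.toList   // two GenericInternalRows wrapping UTF8Strings "foo" and "bar"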
Example 32
Source File: CouchbaseSink.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.sql.streaming

import com.couchbase.spark.Logging
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.types.StringType
import com.couchbase.spark.sql._
import com.couchbase.spark._
import com.couchbase.client.core.CouchbaseException
import com.couchbase.client.java.document.JsonDocument
import com.couchbase.client.java.document.json.JsonObject
import scala.concurrent.duration._



class CouchbaseSink(options: Map[String, String]) extends Sink with Logging {

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    val bucketName = options.get("bucket").orNull
    val idFieldName = options.getOrElse("idField", DefaultSource.DEFAULT_DOCUMENT_ID_FIELD)
    val removeIdField = options.getOrElse("removeIdField", "true").toBoolean
    val timeout = options.get("timeout").map(v => Duration(v.toLong, MILLISECONDS))

    val createDocument = options.get("expiry").map(_.toInt)
      .map(expiry => (id: String, content: JsonObject) => JsonDocument.create(id, expiry, content))
      .getOrElse((id: String, content: JsonObject) => JsonDocument.create(id, content))

    data.toJSON
      .queryExecution
      .toRdd
      .map(_.get(0, StringType).asInstanceOf[UTF8String].toString())
      .map { rawJson =>
          val encoded = JsonObject.fromJson(rawJson)
          val id = encoded.get(idFieldName)

          if (id == null) {
              throw new Exception(s"Could not find ID field $idFieldName in $encoded")
          }

          if (removeIdField) {
              encoded.removeKey(idFieldName)
          }

          createDocument(id.toString, encoded)
      }
      .saveToCouchbase(bucketName, StoreMode.UPSERT, timeout)
  }

} 
Example 33
Source File: Utils.scala    From hbase-connectors   with Apache License 2.0 5 votes vote down vote up
package org.apache.hadoop.hbase.spark.datasources

import java.sql.{Date, Timestamp}

import org.apache.hadoop.hbase.spark.AvroSerdes
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.yetus.audience.InterfaceAudience;

@InterfaceAudience.Private
object Utils {

  
  def hbaseFieldToScalaType(
      f: Field,
      src: Array[Byte],
      offset: Int,
      length: Int): Any = {
    if (f.exeSchema.isDefined) {
      // If we have avro schema defined, use it to get record, and then convert them to catalyst data type
      val m = AvroSerdes.deserialize(src, f.exeSchema.get)
      val n = f.avroToCatalyst.map(_(m))
      n.get
    } else  {
      // Fall back to atomic type
      f.dt match {
        case BooleanType => src(offset) != 0
        case ByteType => src(offset)
        case ShortType => Bytes.toShort(src, offset)
        case IntegerType => Bytes.toInt(src, offset)
        case LongType => Bytes.toLong(src, offset)
        case FloatType => Bytes.toFloat(src, offset)
        case DoubleType => Bytes.toDouble(src, offset)
        case DateType => new Date(Bytes.toLong(src, offset))
        case TimestampType => new Timestamp(Bytes.toLong(src, offset))
        case StringType => UTF8String.fromBytes(src, offset, length)
        case BinaryType =>
          val newArray = new Array[Byte](length)
          System.arraycopy(src, offset, newArray, 0, length)
          newArray
        // TODO: SparkSqlSerializer.deserialize[Any](src)
        case _ => throw new Exception(s"unsupported data type ${f.dt}")
      }
    }
  }

  // convert input to data type
  def toBytes(input: Any, field: Field): Array[Byte] = {
    if (field.schema.isDefined) {
      // Here we assume the top level type is structType
      val record = field.catalystToAvro(input)
      AvroSerdes.serialize(record, field.schema.get)
    } else {
      field.dt match {
        case BooleanType => Bytes.toBytes(input.asInstanceOf[Boolean])
        case ByteType => Array(input.asInstanceOf[Number].byteValue)
        case ShortType => Bytes.toBytes(input.asInstanceOf[Number].shortValue)
        case IntegerType => Bytes.toBytes(input.asInstanceOf[Number].intValue)
        case LongType => Bytes.toBytes(input.asInstanceOf[Number].longValue)
        case FloatType => Bytes.toBytes(input.asInstanceOf[Number].floatValue)
        case DoubleType => Bytes.toBytes(input.asInstanceOf[Number].doubleValue)
        case DateType | TimestampType => Bytes.toBytes(input.asInstanceOf[java.util.Date].getTime)
        case StringType => Bytes.toBytes(input.toString)
        case BinaryType => input.asInstanceOf[Array[Byte]]
        case _ => throw new Exception(s"unsupported data type ${field.dt}")
      }
    }
  }
} 
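The atomic-type branches above are symmetric: toBytes encodes with HBase's Bytes helpers and hbaseFieldToScalaType decodes with the matching readers. A tiny hedged round-trip sketch for the IntegerType branch:

import org.apache.hadoop.hbase.util.Bytes

// Round-trip an Int through the same Bytes encoding used above.
val encoded: Array[Byte] = Bytes.toBytes(42)
val decoded: Int = Bytes.toInt(encoded, 0) // 42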
Example 34
Source File: GenomicIntervalStrategy.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.utvf

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{DataFrame, GenomicInterval, SparkSession, Strategy}
import org.apache.spark.unsafe.types.UTF8String

case class GIntervalRow(contigName: String, start: Int, end: Int)
class GenomicIntervalStrategy(spark: SparkSession) extends Strategy with Serializable {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case GenomicInterval(contigName, start, end, output) =>
      GenomicIntervalPlan(plan, spark, GIntervalRow(contigName, start, end), output) :: Nil
    case _ => Nil
  }
}

case class GenomicIntervalPlan(
    plan: LogicalPlan,
    spark: SparkSession,
    interval: GIntervalRow,
    output: Seq[Attribute]) extends SparkPlan with Serializable {

  def doExecute(): org.apache.spark.rdd.RDD[InternalRow] = {
    import spark.implicits._

    lazy val genomicInterval = spark.createDataset(Seq(interval))
    genomicInterval
      .rdd
      .map { r =>
        val proj = UnsafeProjection.create(schema)
        proj.apply(InternalRow.fromSeq(Seq(UTF8String.fromString(r.contigName), r.start, r.end)))
      }
  }

  def children: Seq[SparkPlan] = Nil
}
Example 35
Source File: Serialize.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.impl.expressions

import java.io.ByteArrayOutputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, _}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.opencypher.morpheus.impl.expressions.EncodeLong.encodeLong
import org.opencypher.morpheus.impl.expressions.Serialize._
import org.opencypher.okapi.impl.exception


case class Serialize(children: Seq[Expression]) extends Expression {

  override def dataType: DataType = BinaryType

  override def nullable: Boolean = false

  // TODO: Only write length if more than one column is serialized
  override def eval(input: InternalRow): Any = {
    // TODO: Reuse from a pool instead of allocating a new one for each serialization
    val out = new ByteArrayOutputStream()
    children.foreach { child =>
      child.dataType match {
        case BinaryType => write(child.eval(input).asInstanceOf[Array[Byte]], out)
        case StringType => write(child.eval(input).asInstanceOf[UTF8String], out)
        case IntegerType => write(child.eval(input).asInstanceOf[Int], out)
        case LongType => write(child.eval(input).asInstanceOf[Long], out)
        case other => throw exception.UnsupportedOperationException(s"Cannot serialize Spark data type $other.")
      }
    }
    out.toByteArray
  }

  override protected def doGenCode(
    ctx: CodegenContext,
    ev: ExprCode
  ): ExprCode = {
    ev.isNull = FalseLiteral
    val out = ctx.freshName("out")
    val serializeChildren = children.map { child =>
      val childEval = child.genCode(ctx)
      s"""|${childEval.code}
          |if (!${childEval.isNull}) {
          |  ${Serialize.getClass.getName.dropRight(1)}.write(${childEval.value}, $out);
          |}""".stripMargin
    }.mkString("\n")
    val baos = classOf[ByteArrayOutputStream].getName
    ev.copy(
      code = code"""|$baos $out = new $baos();
          |$serializeChildren
          |byte[] ${ev.value} = $out.toByteArray();""".stripMargin)
  }

}

object Serialize {

  val supportedTypes: Set[DataType] = Set(BinaryType, StringType, IntegerType, LongType)

  @inline final def write(value: Array[Byte], out: ByteArrayOutputStream): Unit = {
    out.write(encodeLong(value.length))
    out.write(value)
  }

  @inline final def write(
    value: Boolean,
    out: ByteArrayOutputStream
  ): Unit = write(if (value) 1.toLong else 0.toLong, out)

  @inline final def write(value: Byte, out: ByteArrayOutputStream): Unit = write(value.toLong, out)

  @inline final def write(value: Int, out: ByteArrayOutputStream): Unit = write(value.toLong, out)

  @inline final def write(value: Long, out: ByteArrayOutputStream): Unit = write(encodeLong(value), out)

  @inline final def write(value: UTF8String, out: ByteArrayOutputStream): Unit = write(value.getBytes, out)

  @inline final def write(value: String, out: ByteArrayOutputStream): Unit = write(value.getBytes, out)

} 
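A hedged evaluation sketch for the interpreted path (eval) above; literals stand in for arbitrary child expressions:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal

// Serialize one Int column and one String column into a single byte array:
// each value is written as a length-prefixed block via encodeLong.
val expr = Serialize(Seq(Literal(42), Literal("morpheus")))
val bytes = expr.eval(InternalRow.empty).asInstanceOf[Array[Byte]]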
Example 36
Source File: PrefixComparatorsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.collection.unsafe.sort

import com.google.common.primitives.UnsignedBytes
import org.scalatest.prop.PropertyChecks
import org.apache.spark.SparkFunSuite
import org.apache.spark.unsafe.types.UTF8String

class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {

  test("String prefix comparator") {

    def testPrefixComparison(s1: String, s2: String): Unit = {
      val utf8string1 = UTF8String.fromString(s1)
      val utf8string2 = UTF8String.fromString(s2)
      val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1)
      val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2)
      val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)

      val cmp = UnsignedBytes.lexicographicalComparator().compare(
        utf8string1.getBytes.take(8), utf8string2.getBytes.take(8))

      assert(
        (prefixComparisonResult == 0 && cmp == 0) ||
        (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) ||
        (prefixComparisonResult > 0 && s1.compareTo(s2) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
    forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
  }

  test("Binary prefix comparator") {

     def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
      for (i <- 0 until x.length; if i < y.length) {
        val res = x(i).compare(y(i))
        if (res != 0) return res
      }
      x.length - y.length
    }

    def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = {
      val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x)
      val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y)
      val prefixComparisonResult =
        PrefixComparators.BINARY.compare(s1Prefix, s2Prefix)
      assert(
        (prefixComparisonResult == 0) ||
        (prefixComparisonResult < 0 && compareBinary(x, y) < 0) ||
        (prefixComparisonResult > 0 && compareBinary(x, y) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
    forAll { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
  }

  test("double prefix comparator handles NaNs properly") {
    val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
    val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
    assert(nan1.isNaN)
    assert(nan2.isNaN)
    val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
    val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2)
    assert(nan1Prefix === nan2Prefix)
    val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
    assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
  }

} 
Example 37
Source File: DDLTestSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext)
  }
}

case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType =
    StructType(Seq(
      StructField("intType", IntegerType, nullable = false,
        new MetadataBuilder().putString("comment", s"test comment $table").build()),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType",
        StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil
        )
      )
    ))

  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    sqlContext.sparkContext.parallelize(from to to).map { e =>
      InternalRow(UTF8String.fromString(s"people$e"), e * 2)
    }.asInstanceOf[RDD[Row]]
  }
}

class DDLTestSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql(
      """
      |CREATE TEMPORARY TABLE ddlPeople
      |USING org.apache.spark.sql.sources.DDLScanSource
      |OPTIONS (
      |  From '1',
      |  To '10',
      |  Table 'test1'
      |)
      """.stripMargin)
  }

  sqlTest(
      "describe ddlPeople",
      Seq(
        Row("intType", "int", "test comment test1"),
        Row("stringType", "string", ""),
        Row("dateType", "date", ""),
        Row("timestampType", "timestamp", ""),
        Row("doubleType", "double", ""),
        Row("bigintType", "bigint", ""),
        Row("tinyintType", "tinyint", ""),
        Row("decimalType", "decimal(10,0)", ""),
        Row("fixedDecimalType", "decimal(5,1)", ""),
        Row("binaryType", "binary", ""),
        Row("booleanType", "boolean", ""),
        Row("smallIntType", "smallint", ""),
        Row("floatType", "float", ""),
        Row("mapType", "map<string,string>", ""),
        Row("arrayType", "array<string>", ""),
        Row("structType", "struct<f1:string,f2:int>", "")
      ))

  test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
    val attributes = sql("describe ddlPeople")
      .queryExecution.executedPlan.output
    assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
    assert(attributes.map(_.dataType).toSet === Set(StringType))
  }
} 
Example 38
Source File: ExtraStrategiesSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package test.org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    sparkContext.parallelize(Seq(row))
  }

  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      sqlContext.experimental.extraStrategies = Nil
    }
  }
} 
Example 39
Source File: RowSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericMutableRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificMutableRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("serialize w/ kryo") {
    val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first()
    val serializer = new SparkSqlSerializer(sparkContext.getConf)
    val instance = serializer.newInstance()
    val ser = instance.serialize(row)
    val de = instance.deserialize(ser).asInstanceOf[Row]
    assert(de === row)
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 40
Source File: ColumnarTestUtils.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.columnar

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData}
import org.apache.spark.sql.types.{AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String

object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericMutableRow = {
    val row = new GenericMutableRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }

  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case NULL => null
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case LONG => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
      case STRUCT(_) =>
        new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
      case ARRAY(_) =>
        new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
      case MAP(_) =>
        ArrayBasedMapData(
          Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }

  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {

    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericMutableRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }

  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = {

    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericMutableRow(1)
      row(0) = value
      row
    }

    (values, rows)
  }
} 
Example 41
Source File: NumberConverterSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }

} 
Example 42
Source File: StringUtils.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.util.regex.Pattern

import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  // replace the _ with .{1} exactly match 1 time of any character
  // replace the % with .*, match 0 or more times with any character
  def escapeLikeRegex(v: String): String = {
    if (!v.isEmpty) {
      "(?s)" + (' ' +: v.init).zip(v).flatMap {
        case (prev, '\\') => ""
        case ('\\', c) =>
          c match {
            case '_' => "_"
            case '%' => "%"
            case _ => Pattern.quote("\\" + c)
          }
        case (prev, c) =>
          c match {
            case '_' => "."
            case '%' => ".*"
            case _ => Pattern.quote(Character.toString(c))
          }
      }.mkString
    } else {
      v
    }
  }

  private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
  private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)

  def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
  def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
} 
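A hedged sketch of how escapeLikeRegex is meant to be consumed: turn a SQL LIKE pattern into a Java regex and match with it (the escaped form shown in the comment is indicative only):

import java.util.regex.Pattern

val regex = StringUtils.escapeLikeRegex("abc%") // roughly "(?s)\Qa\E\Qb\E\Qc\E.*"
Pattern.compile(regex).matcher("abcdef").matches() // true
Pattern.compile(regex).matcher("abd").matches()    // false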
Example 43
Source File: KinesisRecordToUnsafeRowConverter.scala    From kinesis-sql   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.kinesis

import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

private[kinesis] class KinesisRecordToUnsafeRowConverter {
  private val rowWriter = new UnsafeRowWriter(5)

  def toUnsafeRow(record: Record, streamName: String): UnsafeRow = {
    rowWriter.reset()
    rowWriter.write(0, record.getData.array())
    rowWriter.write(1, UTF8String.fromString(streamName))
    rowWriter.write(2, UTF8String.fromString(record.getPartitionKey))
    rowWriter.write(3, UTF8String.fromString(record.getSequenceNumber))
    rowWriter.write(4, DateTimeUtils.fromJavaTimestamp(
      new java.sql.Timestamp(record.getApproximateArrivalTimestamp.getTime)))
    rowWriter.getRow
  }
} 
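A hedged usage sketch for the converter above, building a Record by hand with the AWS SDK's fluent setters (all values are made up):

import java.nio.ByteBuffer
import com.amazonaws.services.kinesis.model.Record

val record = new Record()
  .withData(ByteBuffer.wrap("payload".getBytes("UTF-8")))
  .withPartitionKey("pk-1")
  .withSequenceNumber("seq-0001")
  .withApproximateArrivalTimestamp(new java.util.Date())
val row = new KinesisRecordToUnsafeRowConverter().toUnsafeRow(record, "my-stream")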
Example 44
Source File: PDFDataSource.scala    From mimir   with Apache License 2.0 5 votes vote down vote up
package mimir.exec.spark.datasource.pdf


import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.v2.reader._
import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.InternalRow
import java.util.{Collections, List => JList, Optional}
import org.apache.spark.unsafe.types.UTF8String
import mimir.exec.spark.datasource.csv.CSVDataSourceReader

class DefaultSource extends DataSourceV2 with ReadSupport {

  def createReader(options: DataSourceOptions) = {
    val path = options.get("path").get
    val pages = options.get("pages").orElse("all")
    val area = Option(options.get("area").orElse(null))
    val hasGrid = options.get("gridLines").orElse("false").toBoolean
    val pdfExtractor = new PDFTableExtractor()
    val outPath = s"${path}.csv"
    pdfExtractor.defaultExtract(path, pages, area, Some(outPath), hasGrid)
    //println(s"------PDFDataSource----$path -> $outPath")
    //println({scala.io.Source.fromFile(outPath).mkString})
    new CSVDataSourceReader(outPath, options.asMap().asScala.toMap)
  }
} 
Example 45
Source File: DataSourceTest.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[sql] abstract class DataSourceTest extends QueryTest {

  protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row], enableRegex: Boolean = false) {
    test(sqlString) {
      withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString) {
        checkAnswer(spark.sql(sqlString), expectedAnswer)
      }
    }
  }

}

class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimpleDDLScan(
      parameters("from").toInt,
      parameters("TO").toInt,
      parameters("Table"))(sqlContext.sparkSession)
  }
}

case class SimpleDDLScan(
    from: Int,
    to: Int,
    table: String)(@transient val sparkSession: SparkSession)
  extends BaseRelation with TableScan {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override def schema: StructType =
    StructType(Seq(
      StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType",
        StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil
        )
      )
    ))

  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    sparkSession.sparkContext.parallelize(from to to).map { e =>
      InternalRow(UTF8String.fromString(s"people$e"), e * 2)
    }.asInstanceOf[RDD[Row]]
  }
} 
Example 46
Source File: XmlDataToCatalyst.scala    From spark-xml   with Apache License 2.0 5 votes vote down vote up
package com.databricks.spark.xml

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import com.databricks.spark.xml.parsers.StaxXmlParser

case class XmlDataToCatalyst(
    child: Expression,
    schema: DataType,
    options: XmlOptions)
  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {

  override lazy val dataType: DataType = schema

  @transient
  lazy val rowSchema: StructType = schema match {
    case st: StructType => st
    case ArrayType(st: StructType, _) => st
  }

  override def nullSafeEval(xml: Any): Any = xml match {
    case string: UTF8String =>
      CatalystTypeConverters.convertToCatalyst(
        StaxXmlParser.parseColumn(string.toString, rowSchema, options))
    case string: String =>
      StaxXmlParser.parseColumn(string, rowSchema, options)
    case arr: GenericArrayData =>
      CatalystTypeConverters.convertToCatalyst(
        arr.array.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options)))
    case arr: Array[_] =>
      arr.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options))
    case _ => null
  }

  override def inputTypes: Seq[DataType] = schema match {
    case _: StructType => Seq(StringType)
    case ArrayType(_: StructType, _) => Seq(ArrayType(StringType))
  }
} 
Example 47
Source File: ColumnarTestUtils.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.columnar

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String

object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericInternalRow = {
    val row = new GenericInternalRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }

  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case NULL => null
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case LONG => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
      case STRUCT(_) =>
        new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
      case ARRAY(_) =>
        new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
      case MAP(_) =>
        ArrayBasedMapData(
          Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
      case _ => throw new IllegalArgumentException(s"Unknown column type $columnType")
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }

  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {

    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericInternalRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }

  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = {

    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericInternalRow(1)
      row(0) = value
      row
    }

    (values, rows)
  }
} 
Example 48
Source File: FailureSafeParser.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

class FailureSafeParser[IN](
    rawParser: IN => Seq[InternalRow],
    mode: ParseMode,
    schema: StructType,
    columnNameOfCorruptRecord: String) {

  private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
  private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
  private val resultRow = new GenericInternalRow(schema.length)
  private val nullResult = new GenericInternalRow(schema.length)

  // This function takes 2 parameters: an optional partial result, and the bad record. If the given
  // schema doesn't contain a field for corrupted record, we just return the partial result or a
  // row with all fields null. If the given schema contains a field for corrupted record, we will
  // set the bad record to this field, and set other fields according to the partial result or null.
  private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = {
    if (corruptFieldIndex.isDefined) {
      (row, badRecord) => {
        var i = 0
        while (i < actualSchema.length) {
          val from = actualSchema(i)
          resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull
          i += 1
        }
        resultRow(corruptFieldIndex.get) = badRecord()
        resultRow
      }
    } else {
      (row, _) => row.getOrElse(nullResult)
    }
  }

  def parse(input: IN): Iterator[InternalRow] = {
    try {
      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
    } catch {
      case e: BadRecordException => mode match {
        case PermissiveMode =>
          Iterator(toResultRow(e.partialResult(), e.record))
        case DropMalformedMode =>
          Iterator.empty
        case FailFastMode =>
          throw new SparkException("Malformed records are detected in record parsing. " +
            s"Parse Mode: ${FailFastMode.name}.", e.cause)
      }
    }
  }
} 
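A hedged sketch of wiring the parser up with a trivial rawParser in permissive mode; the schema and corrupt-column name below are illustrative:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.PermissiveMode
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

val schema = StructType(Seq(
  StructField("value", StringType),
  StructField("_corrupt_record", StringType)))
val parser = new FailureSafeParser[String](
  rawParser = s => Seq(InternalRow(UTF8String.fromString(s))),
  mode = PermissiveMode,
  schema = schema,
  columnNameOfCorruptRecord = "_corrupt_record")
val rows = parser.parse("hello").toList // one row: ("hello", null)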
Example 49
Source File: NumberConverterSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }

} 
Example 50
Source File: SortOrderExpressionsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}
import java.util.TimeZone

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._

class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("SortPrefix") {
    val b1 = Literal.create(false, BooleanType)
    val b2 = Literal.create(true, BooleanType)
    val i1 = Literal.create(20132983, IntegerType)
    val i2 = Literal.create(-20132983, IntegerType)
    val l1 = Literal.create(20132983, LongType)
    val l2 = Literal.create(-20132983, LongType)
    val millis = 1524954911000L
    // Explicitly choose a time zone, since Date objects can create different values depending on
    // local time zone of the machine on which the test is running
    val oldDefaultTZ = TimeZone.getDefault
    val d1 = try {
      TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
      Literal.create(new java.sql.Date(millis), DateType)
    } finally {
      TimeZone.setDefault(oldDefaultTZ)
    }
    val t1 = Literal.create(new Timestamp(millis), TimestampType)
    val f1 = Literal.create(0.7788229f, FloatType)
    val f2 = Literal.create(-0.7788229f, FloatType)
    val db1 = Literal.create(0.7788229d, DoubleType)
    val db2 = Literal.create(-0.7788229d, DoubleType)
    val s1 = Literal.create("T", StringType)
    val s2 = Literal.create("This is longer than 8 characters", StringType)
    val bin1 = Literal.create(Array[Byte](12), BinaryType)
    val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]),
      BinaryType)
    val dec1 = Literal(Decimal(20132983L, 10, 2))
    val dec2 = Literal(Decimal(20132983L, 19, 2))
    val dec3 = Literal(Decimal(20132983L, 21, 2))
    val list1 = Literal(List(1, 2), ArrayType(IntegerType))
    val nullVal = Literal.create(null, IntegerType)

    checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)
    checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L)
    checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L)
    checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L)
    checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L)
    checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L)
    // For some reason, the Literal.create code gives us the number of days since the epoch
    checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L)
    checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000)
    checkEvaluation(SortPrefix(SortOrder(f1, Ascending)),
      DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble))
    checkEvaluation(SortPrefix(SortOrder(f2, Ascending)),
      DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble))
    checkEvaluation(SortPrefix(SortOrder(db1, Ascending)),
      DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double]))
    checkEvaluation(SortPrefix(SortOrder(db2, Ascending)),
      DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double]))
    checkEvaluation(SortPrefix(SortOrder(s1, Ascending)),
      StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String]))
    checkEvaluation(SortPrefix(SortOrder(s2, Ascending)),
      StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String]))
    checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)),
      BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]]))
    checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)),
      BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]]))
    checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L)
    checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L)
    checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)),
      DoublePrefixComparator.computePrefix(201329.83d))
    checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L)
    checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null)
  }
} 
Example 51
Source File: NumberConverter.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.unsafe.types.UTF8String

object NumberConverter {

  
  def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = {
    if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
        || Math.abs(toBase) < Character.MIN_RADIX
        || Math.abs(toBase) > Character.MAX_RADIX) {
      return null
    }

    if (n.length == 0) {
      return null
    }

    var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)

    // Copy the digits in the right side of the array
    val temp = new Array[Byte](64)
    var i = 1
    while (i <= n.length - first) {
      temp(temp.length - i) = n(n.length - i)
      i += 1
    }
    char2byte(fromBase, temp.length - n.length + first, temp)

    // Do the conversion by going through a 64 bit integer
    var v = encode(fromBase, temp.length - n.length + first, temp)
    if (negative && toBase > 0) {
      if (v < 0) {
        v = -1
      } else {
        v = -v
      }
    }
    if (toBase < 0 && v < 0) {
      v = -v
      negative = true
    }
    decode(v, Math.abs(toBase), temp)

    // Find the first non-zero digit or the last digits if all are zero.
    val firstNonZeroPos = {
      val firstNonZero = temp.indexWhere( _ != 0)
      if (firstNonZero != -1) firstNonZero else temp.length - 1
    }
    byte2char(Math.abs(toBase), firstNonZeroPos, temp)

    var resultStartPos = firstNonZeroPos
    if (negative && toBase < 0) {
      resultStartPos = firstNonZeroPos - 1
      temp(resultStartPos) = '-'
    }
    UTF8String.fromBytes(java.util.Arrays.copyOfRange(temp, resultStartPos, temp.length))
  }
} 
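A quick usage sketch for the converter above; the demo object name is hypothetical, and the expected outputs mirror the checks in the NumberConverterSuite that appears later on this page.

import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.unsafe.types.UTF8String

object NumberConverterDemo {
  def main(args: Array[String]): Unit = {
    // "3" in base 10 is "11" in base 2
    println(NumberConverter.convert(UTF8String.fromString("3").getBytes, 10, 2))
    // A negative target base keeps the sign
    println(NumberConverter.convert(UTF8String.fromString("-15").getBytes, 10, -16))
    // A positive target base treats negative inputs as unsigned 64-bit values
    println(NumberConverter.convert(UTF8String.fromString("-15").getBytes, 10, 16))
  }
}

Expected output: 11, -F and FFFFFFFFFFFFFFF1, matching the test suite.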
Example 52
Source File: StringUtils.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.util.regex.{Pattern, PatternSyntaxException}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  
  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() }
      } catch {
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }
} 
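A small sketch of the pattern syntax accepted by filterPattern: `*` is a wildcard, `|` separates alternative patterns, and names must match a whole pattern case-insensitively. The demo object is hypothetical.

import org.apache.spark.sql.catalyst.util.StringUtils

object FilterPatternDemo {
  def main(args: Array[String]): Unit = {
    val names = Seq("map", "flatMap", "filter", "reduce")
    // Keeps names starting with "f" or "m"; the result is sorted and de-duplicated
    println(StringUtils.filterPattern(names, "f*|m*")) // filter, flatMap, map
  }
}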
Example 53
Source File: CreateJacksonParser.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.json

import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}
import java.nio.channels.Channels
import java.nio.charset.Charset

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text
import sun.nio.cs.StreamDecoder

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

private[sql] object CreateJacksonParser extends Serializable {
  def string(jsonFactory: JsonFactory, record: String): JsonParser = {
    jsonFactory.createParser(record)
  }

  def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
    val bb = record.getByteBuffer
    assert(bb.hasArray)

    val bain = new ByteArrayInputStream(
      bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())

    jsonFactory.createParser(new InputStreamReader(bain, "UTF-8"))
  }

  def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
    jsonFactory.createParser(record.getBytes, 0, record.getLength)
  }

  // Jackson parsers can be ranked according to their performance:
  // 1. Array based, with the array actually encoded in UTF-8. This is the fastest parser,
  //    but it doesn't allow setting the encoding explicitly. The actual encoding is detected
  //    automatically by checking the leading bytes of the array.
  // 2. InputStream based, with the stream actually encoded in UTF-8. The encoding is detected
  //    automatically by analyzing the first bytes of the input stream.
  // 3. Reader based. This is the slowest parser used here, but it allows creating a reader
  //    with a specific encoding.
  // This method creates a reader for an array with the given encoding and sizes the internal
  // decoding buffer according to the size of the input array.
  private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = {
    val bais = new ByteArrayInputStream(in, 0, length)
    val byteChannel = Channels.newChannel(bais)
    val decodingBufferSize = Math.min(length, 8192)
    val decoder = Charset.forName(enc).newDecoder()

    StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize)
  }

  def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = {
    val sd = getStreamDecoder(enc, record.getBytes, record.getLength)
    jsonFactory.createParser(sd)
  }

  def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = {
    jsonFactory.createParser(is)
  }

  def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
    jsonFactory.createParser(new InputStreamReader(is, enc))
  }

  def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
    val ba = row.getBinary(0)

    jsonFactory.createParser(ba, 0, ba.length)
  }

  def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
    val binary = row.getBinary(0)
    val sd = getStreamDecoder(enc, binary, binary.length)

    jsonFactory.createParser(sd)
  }
} 
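A minimal sketch of the UTF8String entry point above. CreateJacksonParser is private[sql], so the sketch is assumed to sit in the same package; the JsonFactory is left at its defaults and the demo object is hypothetical.

package org.apache.spark.sql.catalyst.json

import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.unsafe.types.UTF8String

object JacksonParserDemo {
  def main(args: Array[String]): Unit = {
    val factory = new JsonFactory()
    val record = UTF8String.fromString("""{"name":"spark","version":3}""")
    val parser = CreateJacksonParser.utf8String(factory, record)
    // Drain the token stream to show the parser is wired to the record's bytes
    var token = parser.nextToken()
    while (token != null) {
      println(token)
      token = parser.nextToken()
    }
    parser.close()
  }
}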
Example 54
Source File: misc.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.UUID

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


@ExpressionDescription(
  usage = "_FUNC_() - Returns the current database.",
  examples = """
    Examples:
      > SELECT _FUNC_();
       default
  """)
case class CurrentDatabase() extends LeafExpression with Unevaluable {
  override def dataType: DataType = StringType
  override def foldable: Boolean = true
  override def nullable: Boolean = false
  override def prettyName: String = "current_database"
}

// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""",
  examples = """
    Examples:
      > SELECT _FUNC_();
       46707d92-02f4-4817-8116-a4c3b23e6266
  """,
  note = "The function is non-deterministic.")
// scalastyle:on line.size.limit
case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful
    with ExpressionWithRandomSeed {

  def this() = this(None)

  override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))

  override lazy val resolved: Boolean = randomSeed.isDefined

  override def nullable: Boolean = false

  override def dataType: DataType = StringType

  @transient private[this] var randomGenerator: RandomUUIDGenerator = _

  override protected def initializeInternal(partitionIndex: Int): Unit =
    randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex)

  override protected def evalInternal(input: InternalRow): Any =
    randomGenerator.getNextUUIDUTF8String()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val randomGen = ctx.freshName("randomGen")
    ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen,
      forceInline = true,
      useFreshName = false)
    ctx.addPartitionInitializationStatement(s"$randomGen = " +
      "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
      s"${randomSeed.get}L + partitionIndex);")
    ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
      isNull = FalseLiteral)
  }

  override def freshCopy(): Uuid = Uuid(randomSeed)
} 
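A brief sketch of evaluating Uuid interpretively (outside codegen). It assumes the standard initialize/eval contract of nondeterministic expressions; the demo object is hypothetical.

import org.apache.spark.sql.catalyst.expressions.Uuid

object UuidDemo {
  def main(args: Array[String]): Unit = {
    val uuid = Uuid(Some(42L)) // only a seeded Uuid is considered resolved
    uuid.initialize(0)         // seed + partition index drive the RandomUUIDGenerator
    println(uuid.eval())       // a 36-character UUID string, returned as UTF8String
    println(uuid.eval())       // each call draws a new value from the generator
  }
}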
Example 55
Source File: inputFileBlock.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String


@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = false

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileBlockHolder.getInputFilePath
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
      isNull = FalseLiteral)
  }
}


@ExpressionDescription(
  usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.")
case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_start"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getStartOffset
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
  }
}


@ExpressionDescription(
  usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.")
case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_length"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getLength
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
  }
} 
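All three expressions are exposed as SQL functions of the same names, so the usual way to exercise them is through a query. A hypothetical sketch, with a placeholder input path:

import org.apache.spark.sql.SparkSession

object InputFileBlockDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("input-file-block").getOrCreate()
    spark.read.textFile("/tmp/some-input")  // placeholder path
      .selectExpr("input_file_name()", "input_file_block_start()", "input_file_block_length()")
      .show(truncate = false)
    spark.stop()
  }
}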
Example 56
Source File: InternalRow.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Excerpt: only the type-specialized accessor factory from the InternalRow companion object
// is reproduced here.
object InternalRow {
  def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
    case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
    case ByteType => (input, ordinal) => input.getByte(ordinal)
    case ShortType => (input, ordinal) => input.getShort(ordinal)
    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
    case FloatType => (input, ordinal) => input.getFloat(ordinal)
    case DoubleType => (input, ordinal) => input.getDouble(ordinal)
    case StringType => (input, ordinal) => input.getUTF8String(ordinal)
    case BinaryType => (input, ordinal) => input.getBinary(ordinal)
    case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
    case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
    case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
    case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
    case _: MapType => (input, ordinal) => input.getMap(ordinal)
    case u: UserDefinedType[_] => getAccessor(u.sqlType)
    case _ => (input, ordinal) => input.get(ordinal, dataType)
  }
} 
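A small sketch of using the accessor factory with an in-memory row. InternalRow.apply (not shown in the excerpt) builds a GenericInternalRow; the demo object is hypothetical.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object GetAccessorDemo {
  def main(args: Array[String]): Unit = {
    val row = InternalRow(UTF8String.fromString("spark"), 3)
    // Accessors are built once per data type and reused for every (row, ordinal) pair
    val getString = InternalRow.getAccessor(StringType)
    val getInt = InternalRow.getAccessor(IntegerType)
    println(getString(row, 0)) // spark
    println(getInt(row, 1))    // 3
  }
}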
Example 57
Source File: DeltaRetentionSuiteBase.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.delta

import java.io.File

import org.apache.spark.sql.delta.DeltaOperations.Truncate
import org.apache.spark.sql.delta.actions.Metadata
import org.apache.spark.sql.delta.util.FileNames
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.unsafe.types.UTF8String

trait DeltaRetentionSuiteBase extends QueryTest
  with SharedSparkSession {
  protected val testOp = Truncate()

  protected override def sparkConf: SparkConf = super.sparkConf
    // Disable the log cleanup because it runs asynchronously and causes test flakiness
    .set("spark.databricks.delta.properties.defaults.enableExpiredLogCleanup", "false")

  protected def intervalStringToMillis(str: String): Long = {
    DeltaConfigs.getMilliSeconds(
      IntervalUtils.safeStringToInterval(UTF8String.fromString(str)))
  }

  protected def getDeltaFiles(dir: File): Seq[File] =
    dir.listFiles().filter(_.getName.endsWith(".json"))

  protected def getCheckpointFiles(dir: File): Seq[File] =
    dir.listFiles().filter(f => FileNames.isCheckpointFile(new Path(f.getCanonicalPath)))

  protected def getLogFiles(dir: File): Seq[File]

  
  protected def startTxnWithManualLogCleanup(log: DeltaLog): OptimisticTransaction = {
    val txn = log.startTransaction()
    // This will pick up `spark.databricks.delta.properties.defaults.enableExpiredLogCleanup` to
    // disable log cleanup.
    txn.updateMetadata(Metadata())
    txn
  }

  test("startTxnWithManualLogCleanup") {
    withTempDir { tempDir =>
      val log = DeltaLog(spark, new Path(tempDir.getCanonicalPath))
      startTxnWithManualLogCleanup(log).commit(Nil, testOp)
      assert(!log.enableExpiredLogCleanup)
    }
  }
} 
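The interval-parsing half of intervalStringToMillis can be sketched on its own; safeStringToInterval returns null on malformed input instead of throwing. The demo object is hypothetical.

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.UTF8String

object IntervalParsingDemo {
  def main(args: Array[String]): Unit = {
    // A well-formed string parses to a CalendarInterval
    println(IntervalUtils.safeStringToInterval(UTF8String.fromString("interval 30 days")))
    // A malformed string comes back as null rather than an exception
    println(IntervalUtils.safeStringToInterval(UTF8String.fromString("not an interval")))
  }
}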
Example 58
Source File: DataConverter.scala    From spark-cdm   with MIT License 5 votes vote down vote up
package com.microsoft.cdm.utils

import java.text.SimpleDateFormat
import java.util.{Locale, TimeZone}
import java.sql.Timestamp

import org.apache.commons.lang.time.DateUtils
import org.apache.spark.sql.catalyst.util.TimestampFormatter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


class DataConverter() extends Serializable {

  val dateFormatter = new SimpleDateFormat(Constants.SINGLE_DATE_FORMAT)
  val timestampFormatter = TimestampFormatter(Constants.TIMESTAMP_FORMAT, TimeZone.getTimeZone("UTC"))


  val toSparkType: Map[CDMDataType.Value, DataType] = Map(
    CDMDataType.int64 -> LongType,
    CDMDataType.dateTime -> DateType,
    CDMDataType.string -> StringType,
    CDMDataType.double -> DoubleType,
    CDMDataType.decimal -> DecimalType(Constants.DECIMAL_PRECISION,0),
    CDMDataType.boolean -> BooleanType,
    CDMDataType.dateTimeOffset -> TimestampType
  )

  def jsonToData(dt: DataType, value: String): Any = {
    return dt match {
      case LongType => value.toLong
      case DoubleType => value.toDouble
      case DecimalType() => Decimal(value)
      case BooleanType => value.toBoolean
      case DateType => dateFormatter.parse(value)
      case TimestampType => timestampFormatter.parse(value)
      case _ => UTF8String.fromString(value)
    }
  }

  def toCdmType(dt: DataType): CDMDataType.Value = {
    return dt match {
      case IntegerType => CDMDataType.int64
      case LongType => CDMDataType.int64
      case DateType => CDMDataType.dateTime
      case StringType => CDMDataType.string
      case DoubleType => CDMDataType.double
      case DecimalType() => CDMDataType.decimal
      case BooleanType => CDMDataType.boolean
      case TimestampType => CDMDataType.dateTimeOffset
    }
  }  

  def dataToString(data: Any, dataType: DataType): String = {
    (dataType, data) match {
      case (_, null) => null
      case (DateType, _) => dateFormatter.format(data)
      case (TimestampType, v: Number) => timestampFormatter.format(data.asInstanceOf[Long])
      case _ => data.toString
    }
  }

} 
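A hypothetical driver for the converter above, exercising the jsonToData and dataToString paths shown:

import org.apache.spark.sql.types.{BooleanType, LongType, StringType}

import com.microsoft.cdm.utils.DataConverter

object DataConverterDemo {
  def main(args: Array[String]): Unit = {
    val converter = new DataConverter()
    println(converter.jsonToData(LongType, "42"))      // 42 as a Long
    println(converter.jsonToData(BooleanType, "true")) // true as a Boolean
    println(converter.jsonToData(StringType, "plain")) // a UTF8String
    println(converter.dataToString(null, StringType))  // nulls pass through as null
  }
}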
Example 59
Source File: CardinalityHashFunctionTest.scala    From spark-alchemy   with Apache License 2.0 5 votes vote down vote up
package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.{Matchers, WordSpec}


class CardinalityHashFunctionTest extends WordSpec with Matchers {

  "Cardinality hash functions" should {
    "account for nulls" in {
      val a = UTF8String.fromString("a")

      allDistinct(Seq(
        null,
        Array.empty[Byte],
        Array.apply(1.toByte)
      ), BinaryType)

      allDistinct(Seq(
        null,
        UTF8String.fromString(""),
        a
      ), StringType)

      allDistinct(Seq(
        null,
        ArrayData.toArrayData(Array.empty),
        ArrayData.toArrayData(Array(null)),
        ArrayData.toArrayData(Array(null, null)),
        ArrayData.toArrayData(Array(a, null)),
        ArrayData.toArrayData(Array(null, a))
      ), ArrayType(StringType))


      allDistinct(Seq(
        null,
        ArrayBasedMapData(Map.empty),
        ArrayBasedMapData(Map(null.asInstanceOf[String] -> null))
      ), MapType(StringType, StringType))

      allDistinct(Seq(
        null,
        InternalRow(null),
        InternalRow(a)
      ), new StructType().add("foo", StringType))

      allDistinct(Seq(
        InternalRow(null, a),
        InternalRow(a, null)
      ), new StructType().add("foo", StringType).add("bar", StringType))
    }
  }

  def allDistinct(values: Seq[Any], dataType: DataType): Unit = {
    val hashed = values.map(x => CardinalityXxHash64Function.hash(x, dataType, 0))
    hashed.distinct.length should be(hashed.length)
  }

} 
Example 60
Source File: ColumnarTestUtils.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.columnar

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String

object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericInternalRow = {
    val row = new GenericInternalRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }

  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case NULL => null
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case LONG => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
      case STRUCT(_) =>
        new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
      case ARRAY(_) =>
        new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
      case MAP(_) =>
        ArrayBasedMapData(
          Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
      case _ => throw new IllegalArgumentException(s"Unknown column type $columnType")
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }

  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {

    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericInternalRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }

  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = {

    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericInternalRow(1)
      row(0) = value
      row
    }

    (values, rows)
  }
} 
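A hypothetical driver for these test helpers. It is assumed to live in the same package, since the column type singletons (INT, STRING, ...) are package-private in Spark.

package org.apache.spark.sql.execution.columnar

object ColumnarTestUtilsDemo {
  def main(args: Array[String]): Unit = {
    println(ColumnarTestUtils.makeRandomValues(INT, STRING))  // one random value per column type
    println(ColumnarTestUtils.makeRandomRow(INT, STRING))     // a fresh random row for the same types
    println(ColumnarTestUtils.makeUniqueRandomValues(INT, 5)) // five distinct random ints
  }
}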
Example 61
Source File: NumberConverterSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }

} 
Example 62
Source File: MapDataSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 63
Source File: NumberConverter.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.unsafe.types.UTF8String

object NumberConverter {

  
  def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = {
    if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
        || Math.abs(toBase) < Character.MIN_RADIX
        || Math.abs(toBase) > Character.MAX_RADIX) {
      return null
    }

    if (n.length == 0) {
      return null
    }

    var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)

    // Copy the digits in the right side of the array
    val temp = new Array[Byte](64)
    var i = 1
    while (i <= n.length - first) {
      temp(temp.length - i) = n(n.length - i)
      i += 1
    }
    char2byte(fromBase, temp.length - n.length + first, temp)

    // Do the conversion by going through a 64 bit integer
    var v = encode(fromBase, temp.length - n.length + first, temp)
    if (negative && toBase > 0) {
      if (v < 0) {
        v = -1
      } else {
        v = -v
      }
    }
    if (toBase < 0 && v < 0) {
      v = -v
      negative = true
    }
    decode(v, Math.abs(toBase), temp)

    // Find the first non-zero digit or the last digits if all are zero.
    val firstNonZeroPos = {
      val firstNonZero = temp.indexWhere( _ != 0)
      if (firstNonZero != -1) firstNonZero else temp.length - 1
    }
    byte2char(Math.abs(toBase), firstNonZeroPos, temp)

    var resultStartPos = firstNonZeroPos
    if (negative && toBase < 0) {
      resultStartPos = firstNonZeroPos - 1
      temp(resultStartPos) = '-'
    }
    UTF8String.fromBytes(java.util.Arrays.copyOfRange(temp, resultStartPos, temp.length))
  }
} 
Example 64
Source File: StringUtils.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.util.regex.{Pattern, PatternSyntaxException}

import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  // Translate a SQL LIKE pattern into a Java regular expression:
  // `_` becomes `.` (matches exactly one character) and `%` becomes `.*` (matches zero or more)
  def escapeLikeRegex(v: String): String = {
    if (!v.isEmpty) {
      "(?s)" + (' ' +: v.init).zip(v).flatMap {
        case (prev, '\\') => ""
        case ('\\', c) =>
          c match {
            case '_' => "_"
            case '%' => "%"
            case _ => Pattern.quote("\\" + c)
          }
        case (prev, c) =>
          c match {
            case '_' => "."
            case '%' => ".*"
            case _ => Pattern.quote(Character.toString(c))
          }
      }.mkString
    } else {
      v
    }
  }

  private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
  private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)

  def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
  def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)

  
  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() }
      } catch {
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }
} 
Example 65
Source File: InputFileName.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileNameHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String


@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the current file being read if available",
  extended = "> SELECT _FUNC_();\n ''")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = true

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initInternal(): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileNameHolder.getInputFileName()
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
      "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false")
  }
} 
Example 66
Source File: HashingTF.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.feature

import java.lang.{Iterable => JavaIterable}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.unsafe.hash.Murmur3_x86_32._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

// Excerpt from the HashingTF companion object. The enclosing object declaration and its
// private murmur3 seed (42 in Spark's source) are restored so the fragment stands alone.
object HashingTF {

  private val seed = 42
  private[spark] def murmur3Hash(term: Any): Int = {
    term match {
      case null => seed
      case b: Boolean => hashInt(if (b) 1 else 0, seed)
      case b: Byte => hashInt(b, seed)
      case s: Short => hashInt(s, seed)
      case i: Int => hashInt(i, seed)
      case l: Long => hashLong(l, seed)
      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
      case s: String =>
        val utf8 = UTF8String.fromString(s)
        hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
      case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " +
        s"support type ${term.getClass.getCanonicalName} of input data.")
    }
  }
} 
Example 67
Source File: StringUtils.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.util.regex.{Pattern, PatternSyntaxException}

import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  // Translate a SQL LIKE pattern into a Java regular expression:
  // `_` becomes `.` (matches exactly one character) and `%` becomes `.*` (matches zero or more)
  def escapeLikeRegex(v: String): String = {
    if (!v.isEmpty) {
      "(?s)" + (' ' +: v.init).zip(v).flatMap {
        case (prev, '\\') => ""
        case ('\\', c) =>
          c match {
            case '_' => "_"
            case '%' => "%"
            case _ => Pattern.quote("\\" + c)
          }
        case (prev, c) =>
          c match {
            case '_' => "."
            case '%' => ".*"
            case _ => Pattern.quote(Character.toString(c))
          }
      }.mkString
    } else {
      v
    }
  }

  private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
  private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)

  def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
  def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)

  
  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() }
      } catch {
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }
} 
Example 68
Source File: SparkDateTime.scala    From spark-datetime   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sparklinedata.datetime

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@SQLUserDefinedType(udt = classOf[SparkDateTimeUDT])
case class SparkDateTime(millis : Long, tzId : String)

class SparkDateTimeUDT extends UserDefinedType[SparkDateTime] {

  override def sqlType: DataType =
    StructType(Seq(StructField("millis", LongType), StructField("tz", StringType)))

  override def serialize(obj: SparkDateTime): InternalRow = {
    obj match {
      case dt: SparkDateTime =>
        val row = new GenericMutableRow(2)
        row.setLong(0, dt.millis)
        row.update(1, CatalystTypeConverters.convertToCatalyst(dt.tzId))
        row
    }
  }

  override def deserialize(datum: Any): SparkDateTime = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 2,
          s"SparkDateTimeUDT.deserialize given row with length ${row.numFields} " +
            s"but requires length == 2")
        SparkDateTime(row.getLong(0), row.getString(1))
    }
  }

  override def userClass: Class[SparkDateTime] = classOf[SparkDateTime]

  override def asNullable: SparkDateTimeUDT = this
}

@SQLUserDefinedType(udt = classOf[SparkPeriodUDT])
case class SparkPeriod(periodIsoStr : String)

class SparkPeriodUDT extends UserDefinedType[SparkPeriod] {

  override def sqlType: DataType = StringType


  override def serialize(obj: SparkPeriod): Any = {
    obj match {
      case p: SparkPeriod =>
        CatalystTypeConverters.convertToCatalyst(p.periodIsoStr)
    }
  }

  override def deserialize(datum: Any): SparkPeriod = {
    datum match {
      case s : UTF8String =>
        SparkPeriod(s.toString())
    }
  }

  override def userClass: Class[SparkPeriod] = classOf[SparkPeriod]

  override def asNullable: SparkPeriodUDT = this
}

@SQLUserDefinedType(udt = classOf[SparkIntervalUDT])
case class SparkInterval(intervalIsoStr : String)

class SparkIntervalUDT extends UserDefinedType[SparkInterval] {

  override def sqlType: DataType = StringType


  override def serialize(obj: SparkInterval): Any = {
    obj match {
      case i: SparkInterval =>
        CatalystTypeConverters.convertToCatalyst(i.intervalIsoStr)
    }
  }

  override def deserialize(datum: Any): SparkInterval = {
    datum match {
      case s : UTF8String =>
        SparkInterval(s.toString())
    }
  }

  override def userClass: Class[SparkInterval] = classOf[SparkInterval]

  override def asNullable: SparkIntervalUDT = this
} 
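A round-trip sketch for the date-time UDT above, relying only on the serialize/deserialize methods shown; the demo object is hypothetical.

import org.apache.spark.sparklinedata.datetime.{SparkDateTime, SparkDateTimeUDT}

object SparkDateTimeUDTDemo {
  def main(args: Array[String]): Unit = {
    val udt = new SparkDateTimeUDT
    val original = SparkDateTime(millis = 1456790400000L, tzId = "UTC")
    // serialize produces the two-field InternalRow described by sqlType;
    // deserialize reads it back into the user-facing case class
    val roundTripped = udt.deserialize(udt.serialize(original))
    assert(roundTripped == original)
    println(roundTripped)
  }
}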
Example 69
Source File: JdbcUtil.scala    From bahir   with Apache License 2.0 5 votes vote down vote up
package org.apache.bahir.sql.streaming.jdbc

import java.sql.{Connection, PreparedStatement}
import java.util.Locale

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


object JdbcUtil {

  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }

  // A `JDBCValueSetter` sets a value from a `Row` into a field of a `PreparedStatement`.
  // Its last `Int` argument is the zero-based column index: it indexes directly into the
  // `Row`, and the value is bound to parameter `index + 1` of the SQL statement.
  type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val strValue = row.get(pos) match {
          case str: UTF8String => str.toString
          case str: String => str
        }
        stmt.setString(pos + 1, strValue)

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase(Locale.ROOT).split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
} 
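A hypothetical helper built on makeSetter that binds every column of a Row onto an already-prepared statement. The connection, dialect, schema, and statement are supplied by the caller, so the sketch compiles without a live database.

import java.sql.{Connection, PreparedStatement}

import org.apache.spark.sql.Row
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.StructType

import org.apache.bahir.sql.streaming.jdbc.JdbcUtil

object JdbcUtilSketch {
  def bindRow(
      conn: Connection,
      dialect: JdbcDialect,
      schema: StructType,
      stmt: PreparedStatement,
      row: Row): Unit = {
    // One setter per column, resolved from the column's Spark data type
    val setters = schema.fields.map(f => JdbcUtil.makeSetter(conn, dialect, f.dataType))
    setters.zipWithIndex.foreach { case (set, i) => set(stmt, row, i) }
  }
}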
Example 70
Source File: JavaConverter.scala    From spark-dynamodb   with Apache License 2.0 5 votes vote down vote up
package com.audienceproject.spark.dynamodb.catalyst

import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.JavaConverters._

object JavaConverter {

    def convertRowValue(row: InternalRow, index: Int, elementType: DataType): Any = {
        elementType match {
            case ArrayType(innerType, _) => convertArray(row.getArray(index), innerType)
            case MapType(keyType, valueType, _) => convertMap(row.getMap(index), keyType, valueType)
            case StructType(fields) => convertStruct(row.getStruct(index, fields.length), fields)
            case StringType => row.getString(index)
            case _ => row.get(index, elementType)
        }
    }

    def convertArray(array: ArrayData, elementType: DataType): Any = {
        elementType match {
            case ArrayType(innerType, _) => array.toSeq[ArrayData](elementType).map(convertArray(_, innerType)).asJava
            case MapType(keyType, valueType, _) => array.toSeq[MapData](elementType).map(convertMap(_, keyType, valueType)).asJava
            case structType: StructType => array.toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)).asJava
            case StringType => convertStringArray(array).asJava
            case _ => array.toSeq[Any](elementType).asJava
        }
    }

    def convertMap(map: MapData, keyType: DataType, valueType: DataType): util.Map[String, Any] = {
        if (keyType != StringType) throw new IllegalArgumentException(
            s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
        val keys = convertStringArray(map.keyArray())
        val values = valueType match {
            case ArrayType(innerType, _) => map.valueArray().toSeq[ArrayData](valueType).map(convertArray(_, innerType))
            case MapType(innerKeyType, innerValueType, _) => map.valueArray().toSeq[MapData](valueType).map(convertMap(_, innerKeyType, innerValueType))
            case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(convertStruct(_, structType.fields))
            case StringType => convertStringArray(map.valueArray())
            case _ => map.valueArray().toSeq[Any](valueType)
        }
        val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i)
        Map(kvPairs: _*).asJava
    }

    def convertStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = {
        val kvPairs = for (i <- 0 until row.numFields) yield
            if (row.isNullAt(i)) fields(i).name -> null
            else fields(i).name -> convertRowValue(row, i, fields(i).dataType)
        Map(kvPairs: _*).asJava
    }

    def convertStringArray(array: ArrayData): Seq[String] =
        array.toSeq[UTF8String](StringType).map(_.toString)

} 
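A sketch converting catalyst values to their Java counterparts with the helpers above; the demo object is hypothetical.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

import com.audienceproject.spark.dynamodb.catalyst.JavaConverter

object JavaConverterDemo {
  def main(args: Array[String]): Unit = {
    // A catalyst string array becomes a java.util.List[String]
    val strings = ArrayData.toArrayData(Array(UTF8String.fromString("a"), UTF8String.fromString("b")))
    println(JavaConverter.convertArray(strings, StringType))

    // String columns come back as plain Java Strings, other types as-is
    val row = InternalRow(UTF8String.fromString("hello"), 1)
    println(JavaConverter.convertRowValue(row, 0, StringType))
    println(JavaConverter.convertRowValue(row, 1, IntegerType))
  }
}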
Example 71
Source File: TypeConversion.scala    From spark-dynamodb   with Apache License 2.0 5 votes vote down vote up
package com.audienceproject.spark.dynamodb.datasource

import com.amazonaws.services.dynamodbv2.document.{IncompatibleTypeException, Item}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.JavaConverters._

private[dynamodb] object TypeConversion {

    def apply(attrName: String, sparkType: DataType): Item => Any =

        sparkType match {
            case BooleanType => nullableGet(_.getBOOL)(attrName)
            case StringType => nullableGet(item => attrName => UTF8String.fromString(item.getString(attrName)))(attrName)
            case IntegerType => nullableGet(_.getInt)(attrName)
            case LongType => nullableGet(_.getLong)(attrName)
            case DoubleType => nullableGet(_.getDouble)(attrName)
            case FloatType => nullableGet(_.getFloat)(attrName)
            case BinaryType => nullableGet(_.getBinary)(attrName)
            case DecimalType() => nullableGet(_.getNumber)(attrName)
            case ArrayType(innerType, _) =>
                nullableGet(_.getList)(attrName).andThen(extractArray(convertValue(innerType)))
            case MapType(keyType, valueType, _) =>
                if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
                nullableGet(_.getRawMap)(attrName).andThen(extractMap(convertValue(valueType)))
            case StructType(fields) =>
                val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) })
                nullableGet(_.getRawMap)(attrName).andThen(extractStruct(nestedConversions))
            case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.")
        }

    private val stringConverter = (value: Any) => UTF8String.fromString(value.asInstanceOf[String])

    private def convertValue(sparkType: DataType): Any => Any =

        sparkType match {
            case IntegerType => nullableConvert(_.intValue())
            case LongType => nullableConvert(_.longValue())
            case DoubleType => nullableConvert(_.doubleValue())
            case FloatType => nullableConvert(_.floatValue())
            case DecimalType() => nullableConvert(identity)
            case ArrayType(innerType, _) => extractArray(convertValue(innerType))
            case MapType(keyType, valueType, _) =>
                if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
                extractMap(convertValue(valueType))
            case StructType(fields) =>
                val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) })
                extractStruct(nestedConversions)
            case BooleanType => {
                case boolean: Boolean => boolean
                case _ => null
            }
            case StringType => {
                case string: String => UTF8String.fromString(string)
                case _ => null
            }
            case BinaryType => {
                case byteArray: Array[Byte] => byteArray
                case _ => null
            }
            case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.")
        }

    private def nullableGet(getter: Item => String => Any)(attrName: String): Item => Any = {
        case item if item.hasAttribute(attrName) => try getter(item)(attrName) catch {
            case _: NumberFormatException => null
            case _: IncompatibleTypeException => null
        }
        case _ => null
    }

    private def nullableConvert(converter: java.math.BigDecimal => Any): Any => Any = {
        case item: java.math.BigDecimal => converter(item)
        case _ => null
    }

    private def extractArray(converter: Any => Any): Any => Any = {
        case list: java.util.List[_] => new GenericArrayData(list.asScala.map(converter))
        case set: java.util.Set[_] => new GenericArrayData(set.asScala.map(converter).toSeq)
        case _ => null
    }

    private def extractMap(converter: Any => Any): Any => Any = {
        case map: java.util.Map[_, _] => ArrayBasedMapData(map, stringConverter, converter)
        case _ => null
    }

    private def extractStruct(conversions: Seq[(String, Any => Any)]): Any => Any = {
        case map: java.util.Map[_, _] => InternalRow.fromSeq(conversions.map({
            case (name, conv) => conv(map.get(name))
        }))
        case _ => null
    }

} 
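A sketch of reading typed attributes from a DynamoDB Item. It is assumed to live in the same package because the object is private[dynamodb]; the attribute names and the demo object are hypothetical.

package com.audienceproject.spark.dynamodb.datasource

import com.amazonaws.services.dynamodbv2.document.Item
import org.apache.spark.sql.types.{LongType, StringType}

object TypeConversionDemo {
  def main(args: Array[String]): Unit = {
    val item = new Item().withString("name", "spark").withLong("stars", 25000L)

    println(TypeConversion("name", StringType)(item))    // UTF8String "spark"
    println(TypeConversion("stars", LongType)(item))     // 25000
    println(TypeConversion("missing", StringType)(item)) // null: attribute not present
  }
}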
Example 72
Source File: RowSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 73
Source File: ColumnarTestUtils.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.columnar

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String

object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericInternalRow = {
    val row = new GenericInternalRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }

  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case NULL => null
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case LONG => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
      case STRUCT(_) =>
        new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
      case ARRAY(_) =>
        new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
      case MAP(_) =>
        ArrayBasedMapData(
          Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
      case _ => throw new IllegalArgumentException(s"Unknown column type $columnType")
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }

  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {

    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericInternalRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }

  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = {

    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericInternalRow(1)
      row(0) = value
      row
    }

    (values, rows)
  }
} 
Example 74
Source File: NumberConverterSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }

} 
Example 75
Source File: MapDataSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 76
Source File: DictionaryBasedEncoderSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.io

import org.apache.parquet.bytes.BytesInput
import org.apache.parquet.column.page.DictionaryPage
import org.apache.parquet.column.values.dictionary.PlainValuesDictionary.PlainBinaryDictionary
import org.scalacheck.{Arbitrary, Gen, Properties}
import org.scalacheck.Prop.forAll
import org.scalatest.prop.Checkers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.oap.adapter.PropertiesAdapter
import org.apache.spark.sql.execution.datasources.oap.filecache.StringFiberBuilder
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

class DictionaryBasedEncoderCheck extends Properties("DictionaryBasedEncoder") {
  private val rowCountInEachGroup = Gen.choose(1, 1024)
  private val rowCountInLastGroup = Gen.choose(1, 1024)
  private val groupCount = Gen.choose(1, 100)

  property("Encoding/Decoding String Type") = forAll { (values: Array[String]) =>

    forAll(rowCountInEachGroup, rowCountInLastGroup, groupCount) {
      (rowCount, lastCount, groupCount) =>
        if (values.nonEmpty) {
          // This is the 'PLAIN' FiberBuilder to validate the 'Encoding/Decoding'
          // Normally, the test case should be:
          // values => encoded bytes => decoded bytes => decoded values (Using ColumnValues class)
          // Validate if 'values' and 'decoded values' are identical.
          // But ColumnValues only supports reading values from a DataFile, so we have to
          // validate in a different way.
          val referenceFiberBuilder = StringFiberBuilder(rowCount, 0)
          val fiberBuilder = PlainBinaryDictionaryFiberBuilder(rowCount, 0, StringType)
          !(0 until groupCount).exists { group =>
            // If lastCount > rowCount, assume lastCount = rowCount
            val count =
              if (group < groupCount - 1) {
                rowCount
              } else if (lastCount > rowCount) {
                rowCount
              } else {
                lastCount
              }
            (0 until count).foreach { row =>
              fiberBuilder.append(InternalRow(UTF8String.fromString(values(row % values.length))))
              referenceFiberBuilder
                .append(InternalRow(UTF8String.fromString(values(row % values.length))))
            }
            val bytes = fiberBuilder.build().fiberData
            val dictionary = new PlainBinaryDictionary(
              new DictionaryPage(
                BytesInput.from(fiberBuilder.buildDictionary),
                fiberBuilder.getDictionarySize,
                org.apache.parquet.column.Encoding.PLAIN))
            val fiberParser = PlainDictionaryFiberParser(
              new OapDataFileMetaV1(rowCountInEachGroup = rowCount), dictionary, StringType)
            val parsedBytes = fiberParser.parse(bytes, count)
            val referenceBytes = referenceFiberBuilder.build().fiberData
            referenceFiberBuilder.clear()
            referenceFiberBuilder.resetDictionary()
            fiberBuilder.clear()
            fiberBuilder.resetDictionary()
            assert(parsedBytes.length == referenceBytes.length)
            parsedBytes.zip(referenceBytes).exists(byte => byte._1 != byte._2)
          }
        } else {
          true
        }
    }
  }
}

class DictionaryBasedEncoderSuite extends SparkFunSuite with Checkers {

  test("Check Encoding/Decoding") {
    check(PropertiesAdapter.getProp(new DictionaryBasedEncoderCheck()))
  }
} 
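
The same Properties/forAll pattern also works for much simpler UTF8String round-trip checks that do not depend on the OAP fiber classes. A minimal sketch follows; the generator is restricted to alphabetic strings purely to keep the example small.

import org.scalacheck.{Gen, Properties}
import org.scalacheck.Prop.forAll

import org.apache.spark.unsafe.types.UTF8String

class UTF8StringRoundTripCheck extends Properties("UTF8String") {

  private val simpleStrings: Gen[String] = Gen.alphaStr

  // fromString followed by toString should give back the original string.
  property("fromString/toString round trip") = forAll(simpleStrings) { s =>
    UTF8String.fromString(s).toString == s
  }

  // fromBytes followed by getBytes should preserve the UTF-8 encoded bytes.
  property("fromBytes/getBytes round trip") = forAll(simpleStrings) { s =>
    val utf8 = UTF8String.fromString(s)
    UTF8String.fromBytes(utf8.getBytes) == utf8
  }
}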
Example 77
Source File: DeltaByteArrayEncoderSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.io

import org.scalacheck.{Arbitrary, Gen, Properties}
import org.scalacheck.Prop.forAll
import org.scalatest.prop.Checkers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.oap.adapter.PropertiesAdapter
import org.apache.spark.sql.execution.datasources.oap.filecache.StringFiberBuilder
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

class DeltaByteArrayEncoderCheck extends Properties("DeltaByteArrayEncoder") {

  private val rowCountInEachGroup = Gen.choose(1, 1024)
  private val rowCountInLastGroup = Gen.choose(1, 1024)
  private val groupCount = Gen.choose(1, 100)

  property("Encoding/Decoding String Type") = forAll { (values: Array[String]) =>

    forAll(rowCountInEachGroup, rowCountInLastGroup, groupCount) {
      (rowCount, lastCount, groupCount) =>
        if (values.nonEmpty) {
          // 'referenceFiberBuilder' is the 'PLAIN' FiberBuilder used to validate the encoding/decoding.
          // Ideally the test would be:
          // values => encoded bytes => decoded bytes => decoded values (using the ColumnValues class),
          // and then check that 'values' and 'decoded values' are identical.
          // But ColumnValues only supports reading values from a DataFile, so instead we decode the
          // encoded bytes and compare them against the bytes produced by the plain reference builder.
          val referenceFiberBuilder = StringFiberBuilder(rowCount, 0)
          val fiberBuilder = DeltaByteArrayFiberBuilder(rowCount, 0, StringType)
          val fiberParser = DeltaByteArrayDataFiberParser(
            new OapDataFileMetaV1(rowCountInEachGroup = rowCount), StringType)
          !(0 until groupCount).exists { group =>
            // If lastCount > rowCount, assume lastCount = rowCount
            val count = if (group < groupCount - 1) {
              rowCount
            } else if (lastCount > rowCount) {
              rowCount
            } else {
              lastCount
            }
            (0 until count).foreach { row =>
              fiberBuilder.append(InternalRow(UTF8String.fromString(values(row % values.length))))
              referenceFiberBuilder
                .append(InternalRow(UTF8String.fromString(values(row % values.length))))
            }
            val bytes = fiberBuilder.build().fiberData
            val parsedBytes = fiberParser.parse(bytes, count)
            val referenceBytes = referenceFiberBuilder.build().fiberData
            referenceFiberBuilder.clear()
            fiberBuilder.clear()
            assert(parsedBytes.length == referenceBytes.length)
            parsedBytes.zip(referenceBytes).exists(byte => byte._1 != byte._2)
          }
        } else true
    }
  }
}

class DeltaByteArrayEncoderSuite extends SparkFunSuite with Checkers {

  test("Check Encoding/Decoding") {
    check(PropertiesAdapter.getProp(new DeltaByteArrayEncoderCheck()))
  }
} 
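
Both encoder checks compute the per-group row count with the same inline branching ("If lastCount > rowCount, assume lastCount = rowCount"). The rule can be captured in one small helper, shown here purely as an illustration:

object GroupRowCount {
  // Every group except the last holds rowCount rows; the last group holds lastCount rows,
  // capped at rowCount.
  def rowsInGroup(group: Int, groupCount: Int, rowCount: Int, lastCount: Int): Int =
    if (group < groupCount - 1) rowCount else math.min(lastCount, rowCount)

  def main(args: Array[String]): Unit = {
    assert(rowsInGroup(group = 0, groupCount = 3, rowCount = 100, lastCount = 250) == 100)
    assert(rowsInGroup(group = 2, groupCount = 3, rowCount = 100, lastCount = 250) == 100) // capped
    assert(rowsInGroup(group = 2, groupCount = 3, rowCount = 100, lastCount = 40) == 40)
  }
}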
Example 78
Source File: StatisticsTest.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.statistics

import java.io.ByteArrayOutputStream

import scala.collection.mutable.ArrayBuffer

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BaseOrdering
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.datasources.oap.filecache.FiberCache
import org.apache.spark.sql.execution.datasources.oap.index.RangeInterval
import org.apache.spark.sql.execution.datasources.oap.utils.{NonNullKeyReader, NonNullKeyWriter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryBlock
import org.apache.spark.unsafe.types.UTF8String

abstract class StatisticsTest extends SparkFunSuite with BeforeAndAfterEach {

  protected def rowGen(i: Int): InternalRow = InternalRow(i, UTF8String.fromString(s"test#$i"))

  protected lazy val schema: StructType = StructType(StructField("a", IntegerType)
    :: StructField("b", StringType) :: Nil)
  @transient
  protected lazy val nnkw: NonNullKeyWriter = new NonNullKeyWriter(schema)
  @transient
  protected lazy val nnkr: NonNullKeyReader = new NonNullKeyReader(schema)
  @transient
  protected lazy val ordering: BaseOrdering = GenerateOrdering.create(schema)
  @transient
  protected lazy val partialOrdering: BaseOrdering =
    GenerateOrdering.create(StructType(schema.dropRight(1)))
  protected var out: ByteArrayOutputStream = _

  protected var intervalArray: ArrayBuffer[RangeInterval] = new ArrayBuffer[RangeInterval]()

  override def beforeEach(): Unit = {
    out = new ByteArrayOutputStream(8000)
  }

  override def afterEach(): Unit = {
    out.close()
    intervalArray.clear()
  }

  protected def generateInterval(
      start: InternalRow, end: InternalRow,
      startInclude: Boolean, endInclude: Boolean): Unit = {
    intervalArray.clear()
    intervalArray.append(new RangeInterval(start, end, startInclude, endInclude))
  }

  protected def checkInternalRow(row1: InternalRow, row2: InternalRow): Unit = {
    val res = row1 == row2 // the InternalRow implementations used here compare field values
    assert(res, s"row1: $row1 does not match $row2")
  }

  protected def wrapToFiberCache(out: ByteArrayOutputStream): FiberCache = {
    val bytes = out.toByteArray
    FiberCache(bytes)
  }
} 
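
The ordering created above compares rows whose string column is stored as UTF8String. Below is a minimal, self-contained sketch of that comparison using the same GenerateOrdering.create call; the schema and row values are illustrative.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BaseOrdering
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object OrderingSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(StructField("a", IntegerType) :: StructField("b", StringType) :: Nil)
    val ordering: BaseOrdering = GenerateOrdering.create(schema)

    // String columns must be UTF8String inside InternalRow for the generated comparator.
    val row1 = InternalRow(1, UTF8String.fromString("test#1"))
    val row2 = InternalRow(2, UTF8String.fromString("test#2"))

    assert(ordering.compare(row1, row2) < 0)
    assert(ordering.compare(row2, row2) == 0)
  }
}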
Example 79
Source File: NonNullKeySuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.utils

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.oap.filecache.FiberCache
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ByteBufferOutputStream

class NonNullKeySuite extends SparkFunSuite with Logging {

  private lazy val random = new Random(0)
  private lazy val values = {
    val booleans: Seq[Boolean] = Seq(true, false)
    val bytes: Seq[Byte] = Seq(Byte.MinValue, 0, 10, 30, Byte.MaxValue)
    val shorts: Seq[Short] = Seq(Short.MinValue, -100, 0, 10, 200, Short.MaxValue)
    val ints: Seq[Int] = Seq(Int.MinValue, -100, 0, 100, 12346, Int.MaxValue)
    val longs: Seq[Long] = Seq(Long.MinValue, -10000, 0, 20, Long.MaxValue)
    val floats: Seq[Float] = Seq(Float.MinValue, Float.MinPositiveValue, Float.MaxValue)
    val doubles: Seq[Double] = Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue)
    val strings: Seq[UTF8String] =
      Seq("", "test", "b plus tree", "BTreeRecordReaderWriter").map(UTF8String.fromString)
    val binaries: Seq[Array[Byte]] = (0 until 20 by 5).map{ size =>
      val buf = new Array[Byte](size)
      random.nextBytes(buf)
      buf
    }
    val values = booleans ++ bytes ++ shorts ++ ints ++ longs ++
      floats ++ doubles ++ strings ++ binaries ++ Nil
    random.shuffle(values)
  }
  private def toSparkDataType(any: Any): DataType = {
    any match {
      case _: Boolean => BooleanType
      case _: Short => ShortType
      case _: Byte => ByteType
      case _: Int => IntegerType
      case _: Long => LongType
      case _: Float => FloatType
      case _: Double => DoubleType
      case _: UTF8String => StringType
      case _: Array[Byte] => BinaryType
    }
  }

  test("Read/Write Based On Schema") {
    values.grouped(10).foreach { valueSeq =>
      val schema = StructType(valueSeq.zipWithIndex.map {
        case (v, i) => StructField(s"col$i", toSparkDataType(v))
      })
      val nnkw = new NonNullKeyWriter(schema)
      val nnkr = new NonNullKeyReader(schema)
      val row = InternalRow.fromSeq(valueSeq)
      val buf = new ByteBufferOutputStream()
      nnkw.writeKey(buf, row)
      val answerRow = nnkr.readKey(FiberCache(buf.toByteArray), 0)._1
      assert(row.equals(answerRow))
    }
  }
} 
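
The rows built above carry UTF8String for string fields and Array[Byte] for binary fields; reading them back goes through the typed getters on InternalRow. A minimal sketch, with illustrative values:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

object InternalRowAccessSketch {
  def main(args: Array[String]): Unit = {
    val row = InternalRow.fromSeq(Seq(42, UTF8String.fromString("b plus tree"), Array[Byte](1, 2, 3)))

    assert(row.getInt(0) == 42)
    assert(row.getUTF8String(1) == UTF8String.fromString("b plus tree"))
    assert(row.getString(1) == "b plus tree") // getString is getUTF8String(...).toString
    assert(row.getBinary(2).sameElements(Array[Byte](1, 2, 3)))
  }
}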
Example 80
Source File: OapDataReader.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.io

import org.apache.hadoop.fs.FSDataInputStream

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{OapException, PartitionedFile}
import org.apache.spark.sql.execution.datasources.oap.INDEX_STAT._
import org.apache.spark.sql.execution.datasources.oap.OapFileFormat
import org.apache.spark.sql.execution.datasources.oap.io.OapDataFileProperties.DataFileVersion
import org.apache.spark.sql.execution.datasources.oap.io.OapDataFileProperties.DataFileVersion.DataFileVersion
import org.apache.spark.unsafe.types.UTF8String

abstract class OapDataReader {
  def read(file: PartitionedFile): Iterator[InternalRow]

  // The following two members have to be defined by each versioned OapDataReader implementation
  // so that they can be reported to [[OapMetricsManager]].
  def rowsReadByIndex: Option[Long]
  def indexStat: INDEX_STAT
}

object OapDataReader extends Logging {

  def readVersion(is: FSDataInputStream, fileLen: Long): DataFileVersion = {
    val MAGIC_VERSION_LENGTH = 4
    val metaEnd = fileLen - 4

    // seek to the position of data file meta length
    is.seek(metaEnd)
    val metaLength = is.readInt()
    // read the magic and version bytes at the start of the data file meta
    val magicBuffer = new Array[Byte](MAGIC_VERSION_LENGTH)
    is.readFully(metaEnd - metaLength, magicBuffer)

    val magic = UTF8String.fromBytes(magicBuffer).toString
    magic match {
      case m if ! m.contains("OAP") => throw new OapException("Not a valid Oap Data File")
      case m if m == "OAP1" => DataFileVersion.OAP_DATAFILE_V1
      case _ => throw new OapException("Not a supported Oap Data File version")
    }
  }

  def getDataFileClassFor(
      dataReaderClassFromDataSourceMeta: String,
      reader: OapDataReader): String = {
    dataReaderClassFromDataSourceMeta match {
      case c if c == OapFileFormat.PARQUET_DATA_FILE_CLASSNAME => c
      case c if c == OapFileFormat.ORC_DATA_FILE_CLASSNAME => c
      case c if c == OapFileFormat.OAP_DATA_FILE_CLASSNAME =>
        reader match {
          case r: OapDataReaderV1 => OapFileFormat.OAP_DATA_FILE_V1_CLASSNAME
          case _ => throw new OapException(s"Undefined connection for $reader")
        }
      case _ => throw new OapException(
        s"Undefined data reader class name $dataReaderClassFromDataSourceMeta")
    }
  }
} 
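
readVersion decodes the magic bytes with UTF8String.fromBytes; the overload taking an offset and a length can decode a slice of a larger buffer without copying it out first. A small sketch, assuming only the UTF8String API; the buffer contents are illustrative.

import java.nio.charset.StandardCharsets

import org.apache.spark.unsafe.types.UTF8String

object MagicHeaderSketch {
  def main(args: Array[String]): Unit = {
    // A buffer whose first four bytes carry a magic-and-version marker.
    val bytes = "OAP1...rest of the file meta...".getBytes(StandardCharsets.UTF_8)

    // Decode only the first four bytes; no intermediate Array[Byte] copy is needed.
    val magic = UTF8String.fromBytes(bytes, 0, 4).toString
    assert(magic == "OAP1")
    assert(magic.contains("OAP"))
  }
}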
Example 81
Source File: NodeType.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.types

import java.sql.Date

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.unsafe.types.UTF8String


class NodeType extends UserDefinedType[Node] {

  override val sqlType = StructType(Seq(
    StructField("path", ArrayType(StringType, containsNull = false), nullable = false),
    StructField("dataType", StringType, nullable = false),
    StructField("preRank", IntegerType, nullable = true),
    StructField("postRank", IntegerType, nullable = true),
    StructField("isLeaf", BooleanType, nullable = true),
    StructField("ordPath", ArrayType(LongType, containsNull=false), nullable = true)
  ))

  override def serialize(obj: Any): Any = obj match {
    case node: Node =>
      InternalRow(new GenericArrayData(node.path.map {
        case null => null
        case p => UTF8String.fromString(p.toString)
      }),
        UTF8String.fromString(node.pathDataTypeJson),
        node.preRank,
        node.postRank,
        node.isLeaf,
        if (node.ordPath == null){
          node.ordPath
        } else {
          new GenericArrayData(node.ordPath)
        })
    case _ => throw new UnsupportedOperationException(s"Cannot serialize ${obj.getClass}")
  }

  // scalastyle:off cyclomatic.complexity
  override def deserialize(datum: Any): Node = datum match {
    case row: InternalRow => {
      val stringArray = row.getArray(0).toArray[UTF8String](StringType).map {
        case null => null
        case somethingElse => somethingElse.toString
      }
      val readDataTypeString: String = row.getString(1)
      val readDataType: DataType = DataType.fromJson(readDataTypeString)
      val path: Seq[Any] = readDataType match {
        case StringType => stringArray
        case LongType => stringArray.map(v => if (v != null) v.toLong else null)
        case IntegerType => stringArray.map(v => if (v != null) v.toInt else null)
        case DoubleType => stringArray.map(v => if (v != null) v.toDouble else null)
        case FloatType => stringArray.map(v => if (v != null) v.toFloat else null)
        case ByteType => stringArray.map(v => if (v != null) v.toByte else null)
        case BooleanType => stringArray.map(v => if (v != null) v.toBoolean else null)
        case TimestampType => stringArray.map(v => if (v != null) v.toLong else null)
        case dt: DataType => sys.error(s"Type $dt not supported for hierarchy path")
      }
      val preRank: Integer = if (row.isNullAt(2)) null else row.getInt(2)
      val postRank: Integer = if (row.isNullAt(3)) null else row.getInt(3)
      // scalastyle:off magic.number
      val isLeaf: java.lang.Boolean = if (row.isNullAt(4)) null else row.getBoolean(4)
      val ordPath: Seq[Long] = if (row.isNullAt(5)) null else row.getArray(5).toLongArray()
      // scalastyle:on magic.number
      Node(
        path,
        readDataTypeString,
        preRank,
        postRank,
        isLeaf,
        ordPath
      )
    }
    case node: Node => node
    case _ => throw new UnsupportedOperationException(s"Cannot deserialize ${datum.getClass}")
  }
  // scalastyle:on

  override def userClass: java.lang.Class[Node] = classOf[Node]
}

case object NodeType extends NodeType 
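
The UDT above stores the node path as a GenericArrayData of UTF8String and reads it back with ArrayData.toArray. A minimal round-trip sketch of just that conversion, with illustrative path values:

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object PathArraySketch {
  def main(args: Array[String]): Unit = {
    val path = Seq("root", "child", null, "leaf")

    // Serialize: external strings become UTF8String (nulls stay null) inside an ArrayData.
    val arrayData = new GenericArrayData(path.map {
      case null => null
      case p => UTF8String.fromString(p)
    })

    // Deserialize: read the UTF8Strings back and convert them to java.lang.String.
    val roundTripped = arrayData.toArray[UTF8String](StringType).map {
      case null => null
      case s => s.toString
    }.toSeq

    assert(roundTripped == path)
  }
}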
Example 82
Source File: ERPCurrencyConversionExpression.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.currency.erp

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.currency.CurrencyConversionException
import org.apache.spark.sql.currency.erp.ERPConversionLoader.RConversionOptionsCurried
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.util.control.NonFatal



case class ERPCurrencyConversionExpression(
    conversionFunction: RConversionOptionsCurried,
    children: Seq[Expression])
  extends Expression
  with ImplicitCastInputTypes
  with CodegenFallback {

  protected val CLIENT_INDEX = 0
  protected val CONVERSION_TYPE_INDEX = 1
  protected val AMOUNT_INDEX = 2
  protected val FROM_INDEX = 3
  protected val TO_INDEX = 4
  protected val DATE_INDEX = 5
  protected val NUM_ARGS = 6

  protected val errorMessage = "Currency conversion library encountered an internal error"


  override def eval(input: InternalRow): Any = {
    val inputArguments = children.map(_.eval(input))

    require(inputArguments.length == NUM_ARGS, "wrong number of arguments")

    // parse arguments
    val client = Option(inputArguments(CLIENT_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val conversionType =
      Option(inputArguments(CONVERSION_TYPE_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val amount = Option(inputArguments(AMOUNT_INDEX).asInstanceOf[Decimal].toJavaBigDecimal)
    val sourceCurrency =
      Option(inputArguments(FROM_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val targetCurrency = Option(inputArguments(TO_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val date = Option(inputArguments(DATE_INDEX).asInstanceOf[UTF8String]).map(_.toString)

    // perform conversion
    val conversion =
      conversionFunction(client, conversionType, sourceCurrency, targetCurrency, date)
    val resultTry = conversion(amount)

    // If 'resultTry' holds a 'Failure', propagate it: any recoverable failure handling has
    // already taken place. Non-fatal errors are wrapped so that cryptic library errors
    // surface as a CurrencyConversionException.
    resultTry.recover {
      case NonFatal(err) => throw new CurrencyConversionException(errorMessage, err)
    }.get.map(Decimal.apply).orNull
  }

  override def dataType: DataType = DecimalType.forType(DoubleType)

  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] =
    Seq(StringType, StringType, DecimalType, StringType, StringType, StringType)

  def inputNames: Seq[String] =
    Seq("client", "conversion_type", "amount", "source", "target", "date")

  def getChild(name: String): Option[Expression] = {
    inputNames.zip(children).find { case (n, _) => name == n }.map(_._2)
  }
} 
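
Each string argument above is unwrapped with the same null-safe UTF8String-to-Option[String] pattern. The pattern is small enough to capture as a helper, sketched here with an illustrative name:

import org.apache.spark.unsafe.types.UTF8String

object Utf8Args {
  // Evaluated string-typed expression arguments arrive as UTF8String, or as null.
  def asOptString(evaluated: Any): Option[String] =
    Option(evaluated.asInstanceOf[UTF8String]).map(_.toString)

  def main(args: Array[String]): Unit = {
    assert(asOptString(UTF8String.fromString("EUR")).contains("EUR"))
    assert(asOptString(null).isEmpty)
  }
}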
Example 83
Source File: AnnotationParser.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.parser

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AnnotationReference, Expression, Literal}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


  protected def toTableMetadata(metadata: Map[String, Expression]): Metadata = {
    val res = new MetadataBuilder()
    metadata.foreach {
      case (k, v:Literal) =>
        v.dataType match {
          case StringType =>
            if (k.equals("?")) {
              sys.error("column metadata key can not be ?")
            }
            if (k.equals("*")) {
              sys.error("column metadata key can not be *")
            }
            res.putString(k, v.value.asInstanceOf[UTF8String].toString)
          case LongType => res.putLong(k, v.value.asInstanceOf[Long])
          case DoubleType => res.putDouble(k, v.value.asInstanceOf[Double])
          case NullType =>
            res.putString(k, null)
          case a:ArrayType => res.putString(k, v.value.toString)
        }
      case (k, v:AnnotationReference) =>
        sys.error("column metadata can not have a reference to another column metadata")
    }
    res.build()
  }
} 
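
toTableMetadata converts string literals before calling putString because a string-typed Literal carries its value as UTF8String rather than java.lang.String. A minimal sketch of that conversion, assuming only the Catalyst Literal and MetadataBuilder APIs:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StringType}
import org.apache.spark.unsafe.types.UTF8String

object LiteralMetadataSketch {
  def main(args: Array[String]): Unit = {
    // Literal("...") stores its value internally as UTF8String, not java.lang.String.
    val literal = Literal("some annotation")
    assert(literal.dataType == StringType)
    assert(literal.value.isInstanceOf[UTF8String])

    val metadata: Metadata = new MetadataBuilder()
      .putString("comment", literal.value.asInstanceOf[UTF8String].toString)
      .build()
    assert(metadata.getString("comment") == "some annotation")
  }
}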
Example 84
Source File: DataSourceTest.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[sql] abstract class DataSourceTest extends QueryTest {

  protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row], enableRegex: Boolean = false) {
    test(sqlString) {
      withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString) {
        checkAnswer(spark.sql(sqlString), expectedAnswer)
      }
    }
  }

}

class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimpleDDLScan(
      parameters("from").toInt,
      parameters("TO").toInt,
      parameters("Table"))(sqlContext.sparkSession)
  }
}

case class SimpleDDLScan(
    from: Int,
    to: Int,
    table: String)(@transient val sparkSession: SparkSession)
  extends BaseRelation with TableScan {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override def schema: StructType =
    StructType(Seq(
      StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType",
        StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil
        )
      )
    ))

  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    sparkSession.sparkContext.parallelize(from to to).map { e =>
      InternalRow(UTF8String.fromString(s"people$e"), e * 2)
    }.asInstanceOf[RDD[Row]]
  }
} 
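
SimpleDDLScan sets needConversion to false, so its buildScan must emit InternalRows whose string columns are already UTF8String; an external Row, by contrast, carries plain Strings. The contrast, as a minimal illustrative sketch:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

object RowVsInternalRowSketch {
  // External row (needConversion = true, the default): plain JVM types, including java.lang.String.
  val external: Row = Row("people1", 2)

  // Internal row (needConversion = false): string columns must already be UTF8String.
  val internal: InternalRow = InternalRow(UTF8String.fromString("people1"), 2)
}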
Example 85
Source File: RowSuite.scala    From XSQL   with Apache License 2.0 4 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
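
The test above always stores strings into GenericInternalRow as UTF8String. The sketch below illustrates why that matters: storing a plain java.lang.String compiles, but reading it back through the string getter is expected to fail with a ClassCastException. Values are illustrative.

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.unsafe.types.UTF8String

object GenericInternalRowStringSketch {
  def main(args: Array[String]): Unit = {
    val row = new GenericInternalRow(1)

    // Correct: internal rows store strings as UTF8String.
    row.update(0, UTF8String.fromString("this is a string"))
    assert(row.getString(0) == "this is a string")

    // Incorrect: a java.lang.String is accepted by update, but the typed getter
    // expects a UTF8String and should fail at runtime.
    row.update(0, "this is a string")
    val failed =
      try { row.getString(0); false }
      catch { case _: ClassCastException => true }
    assert(failed)
  }
}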
Example 87
Source File: CarbonCountStar.scala    From carbondata   with Apache License 2.0 4 votes vote down vote up
package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.optimizer.CarbonFilters
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

import org.apache.carbondata.core.datastore.impl.FileFactory
import org.apache.carbondata.core.metadata.AbsoluteTableIdentifier
import org.apache.carbondata.core.metadata.schema.table.CarbonTable
import org.apache.carbondata.core.mutate.CarbonUpdateUtil
import org.apache.carbondata.core.statusmanager.StageInputCollector
import org.apache.carbondata.core.util.{CarbonProperties, ThreadLocalSessionInfo}
import org.apache.carbondata.hadoop.api.{CarbonInputFormat, CarbonTableInputFormat}
import org.apache.carbondata.hadoop.util.CarbonInputFormatUtil
import org.apache.carbondata.spark.load.DataLoadProcessBuilderOnSpark

case class CarbonCountStar(
    attributesRaw: Seq[Attribute],
    carbonTable: CarbonTable,
    sparkSession: SparkSession,
    outUnsafeRows: Boolean = true) extends LeafExecNode {

  override def doExecute(): RDD[InternalRow] = {
    ThreadLocalSessionInfo
      .setConfigurationToCurrentThread(sparkSession.sessionState.newHadoopConf())
    val absoluteTableIdentifier = carbonTable.getAbsoluteTableIdentifier
    val (job, tableInputFormat) = createCarbonInputFormat(absoluteTableIdentifier)
    CarbonInputFormat.setQuerySegment(job.getConfiguration, carbonTable)

    // get row count
    var rowCount = CarbonUpdateUtil.getRowCount(
      tableInputFormat.getBlockRowCount(
        job,
        carbonTable,
        CarbonFilters.getPartitions(
          Seq.empty,
          sparkSession,
          TableIdentifier(
            carbonTable.getTableName,
            Some(carbonTable.getDatabaseName))).map(_.asJava).orNull, false),
      carbonTable)

    if (CarbonProperties.isQueryStageInputEnabled) {
      // check for number of row for stage input
      val splits = StageInputCollector.createInputSplits(carbonTable, job.getConfiguration)
      if (!splits.isEmpty) {
        val df = DataLoadProcessBuilderOnSpark.createInputDataFrame(
          sparkSession, carbonTable, splits.asScala)
        rowCount += df.count()
      }
    }

    val valueRaw =
      attributesRaw.head.dataType match {
        case StringType => Seq(UTF8String.fromString(Long.box(rowCount).toString)).toArray
          .asInstanceOf[Array[Any]]
        case _ => Seq(Long.box(rowCount)).toArray.asInstanceOf[Array[Any]]
      }
    val value = new GenericInternalRow(valueRaw)
    val unsafeProjection = UnsafeProjection.create(output.map(_.dataType).toArray)
    val row = if (outUnsafeRows) unsafeProjection(value) else value
    sparkContext.parallelize(Seq(row))
  }

  override def output: Seq[Attribute] = {
    attributesRaw
  }

  private def createCarbonInputFormat(absoluteTableIdentifier: AbsoluteTableIdentifier
  ): (Job, CarbonTableInputFormat[Array[Object]]) = {
    val carbonInputFormat = new CarbonTableInputFormat[Array[Object]]()
    val jobConf: JobConf = new JobConf(FileFactory.getConfiguration)
    SparkHadoopUtil.get.addCredentials(jobConf)
    CarbonInputFormat.setTableInfo(jobConf, carbonTable.getTableInfo)
    val job = new Job(jobConf)
    FileInputFormat.addInputPath(job, new Path(absoluteTableIdentifier.getTablePath))
    CarbonInputFormat
      .setTransactionalTable(job.getConfiguration,
        carbonTable.getTableInfo.isTransactionalTable)
    CarbonInputFormatUtil.setIndexJobIfConfigured(job.getConfiguration)
    (job, carbonInputFormat)
  }
} 
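
The tail of doExecute converts the count to UTF8String when the output attribute is string-typed and then projects the row into an UnsafeRow. A minimal sketch of that conversion path, assuming only the Catalyst row and projection APIs; names and values are illustrative.

import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object CountRowSketch {
  // Build a single-column unsafe row holding the count, as a string or as a long.
  def countRow(rowCount: Long, outputType: DataType): UnsafeRow = {
    val valueRaw: Array[Any] = outputType match {
      case StringType => Array[Any](UTF8String.fromString(rowCount.toString))
      case _ => Array[Any](Long.box(rowCount))
    }
    UnsafeProjection.create(Array[DataType](outputType)).apply(new GenericInternalRow(valueRaw))
  }

  def main(args: Array[String]): Unit = {
    assert(countRow(42L, StringType).getUTF8String(0) == UTF8String.fromString("42"))
    assert(countRow(42L, LongType).getLong(0) == 42L)
  }
}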
Example 88
Source File: BasicCurrencyConversionExpression.scala    From HANAVora-Extensions   with Apache License 2.0 4 votes vote down vote up
package org.apache.spark.sql.currency.basic

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


case class BasicCurrencyConversionExpression(
    conversion: BasicCurrencyConversion,
    children: Seq[Expression])
  extends Expression
  with ImplicitCastInputTypes
  with CodegenFallback {

  protected val AMOUNT_INDEX = 0
  protected val FROM_INDEX = 1
  protected val TO_INDEX = 2
  protected val DATE_INDEX = 3
  protected val NUM_ARGS = 4

  override def eval(input: InternalRow): Any = {
    val inputArguments = children.map(_.eval(input))

    require(inputArguments.length == NUM_ARGS, "wrong number of arguments")

    val sourceCurrency =
      Option(inputArguments(FROM_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val targetCurrency = Option(inputArguments(TO_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val amount = Option(inputArguments(AMOUNT_INDEX).asInstanceOf[Decimal].toJavaBigDecimal)
    val date = Option(inputArguments(DATE_INDEX).asInstanceOf[UTF8String]).map(_.toString)

    (amount, sourceCurrency, targetCurrency, date) match {
      case (Some(a), Some(s), Some(t), Some(d)) => nullSafeEval(a, s, t, d)
      case _ => null
    }
  }

  def nullSafeEval(amount: java.math.BigDecimal,
                   sourceCurrency: String,
                   targetCurrency: String,
                   date: String): Any = {
    conversion.convert(amount, sourceCurrency, targetCurrency, date)
      .get
      .map(Decimal.apply)
      .orNull
  }

  override def dataType: DataType = DecimalType.forType(DoubleType)

  override def nullable: Boolean = true

  // TODO(MD, CS): use DateType but support date string
  override def inputTypes: Seq[AbstractDataType] =
    Seq(DecimalType, StringType, StringType, StringType)
}