org.apache.spark.unsafe.types.UTF8String Scala Examples
The following examples show how to use org.apache.spark.unsafe.types.UTF8String.
Each example is drawn from an open-source project; the source file, originating project, and license are noted in the header above the code.
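Before the project examples, here is a minimal sketch of the core UTF8String API that the examples below rely on (fromString, numChars, numBytes, contains, getBytes, toString). The object name UTF8StringBasics and the sample text are ours; the printed values assume standard UTF-8 encoding.

import org.apache.spark.unsafe.types.UTF8String

object UTF8StringBasics {
  def main(args: Array[String]): Unit = {
    // Build a UTF8String from a java.lang.String; the bytes are stored UTF-8 encoded.
    val s = UTF8String.fromString("héllo")

    println(s.numChars())                              // 5 characters
    println(s.numBytes())                              // 6 bytes: 'é' needs two bytes in UTF-8
    println(s.contains(UTF8String.fromString("llo")))  // true
    println(s.getBytes.length)                         // 6 raw UTF-8 bytes
    println(s.toString)                                // back to a java.lang.String
  }
}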
Example 1
Source File: BooleanStatement.scala From spark-snowflake with Apache License 2.0
package net.snowflake.spark.snowflake.pushdowns.querygeneration

import net.snowflake.spark.snowflake.{ConstantString, SnowflakeSQLStatement}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Contains, EndsWith, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, StartsWith}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

private[querygeneration] object BooleanStatement {
  def unapply(
      expAttr: (Expression, Seq[Attribute])
  ): Option[SnowflakeSQLStatement] = {
    val expr = expAttr._1
    val fields = expAttr._2

    Option(expr match {
      case In(child, list) if list.forall(_.isInstanceOf[Literal]) =>
        convertStatement(child, fields) + "IN" +
          blockStatement(convertStatements(fields, list: _*))
      case IsNull(child) =>
        blockStatement(convertStatement(child, fields) + "IS NULL")
      case IsNotNull(child) =>
        blockStatement(convertStatement(child, fields) + "IS NOT NULL")
      case Not(child) => {
        child match {
          case EqualTo(left, right) =>
            blockStatement(
              convertStatement(left, fields) + "!=" +
                convertStatement(right, fields)
            )
          case GreaterThanOrEqual(left, right) =>
            convertStatement(LessThan(left, right), fields)
          case LessThanOrEqual(left, right) =>
            convertStatement(GreaterThan(left, right), fields)
          case GreaterThan(left, right) =>
            convertStatement(LessThanOrEqual(left, right), fields)
          case LessThan(left, right) =>
            convertStatement(GreaterThanOrEqual(left, right), fields)
          case _ =>
            ConstantString("NOT") + blockStatement(convertStatement(child, fields))
        }
      }
      case Contains(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'%${pattern.toString}%'"
      case EndsWith(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'%${pattern.toString}'"
      case StartsWith(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'${pattern.toString}%'"
      case _ => null
    })
  }
}
Example 2
Source File: ColumnarTestUtils.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { def makeNullRow(length: Int): GenericInternalRow = { val row = new GenericInternalRow(length) (0 until length).foreach(row.setNullAt) row } def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) bytes } (columnType match { case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() case LONG => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) case STRUCT(_) => new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) case ARRAY(_) => new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } def makeRandomValues( head: ColumnType[_], tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } def makeUniqueRandomValues[JvmType]( columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next() }.drop(count).next().toSeq } def makeRandomRow( head: ColumnType[_], tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericInternalRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value } row } def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => val row = new GenericInternalRow(1) row(0) = value row } (values, rows) } }
Example 3
Source File: RowSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class RowSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("create row") { val expected = new GenericInternalRow(4) expected.setInt(0, 2147483647) expected.update(1, UTF8String.fromString("this is a string")) expected.setBoolean(2, false) expected.setNullAt(3) val actual1 = Row(2147483647, "this is a string", false, null) assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { val row = new SpecificInternalRow(Seq(IntegerType)) row(0) = null assert(row.isNullAt(0)) } test("get values by field name on Row created via .toDF") { val row = Seq((1, Seq(1))).toDF("a", "b").first() assert(row.getAs[Int]("a") === 1) assert(row.getAs[Seq[Int]]("b") === Seq(1)) intercept[IllegalArgumentException]{ row.getAs[Int]("c") } } test("float NaN == NaN") { val r1 = Row(Float.NaN) val r2 = Row(Float.NaN) assert(r1 === r2) } test("double NaN == NaN") { val r1 = Row(Double.NaN) val r2 = Row(Double.NaN) assert(r1 === r2) } test("equals and hashCode") { val r1 = Row("Hello") val r2 = Row("Hello") assert(r1 === r2) assert(r1.hashCode() === r2.hashCode()) val r3 = Row("World") assert(r3.hashCode() != r1.hashCode()) } }
Example 4
Source File: FailureSafeParser.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.execution.datasources import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String class FailureSafeParser[IN]( rawParser: IN => Seq[InternalRow], mode: ParseMode, schema: StructType, columnNameOfCorruptRecord: String) { private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) private val resultRow = new GenericInternalRow(schema.length) private val nullResult = new GenericInternalRow(schema.length) // This function takes 2 parameters: an optional partial result, and the bad record. If the given // schema doesn't contain a field for corrupted record, we just return the partial result or a // row with all fields null. If the given schema contains a field for corrupted record, we will // set the bad record to this field, and set other fields according to the partial result or null. private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { if (corruptFieldIndex.isDefined) { (row, badRecord) => { var i = 0 while (i < actualSchema.length) { val from = actualSchema(i) resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull i += 1 } resultRow(corruptFieldIndex.get) = badRecord() resultRow } } else { (row, _) => row.getOrElse(nullResult) } } def parse(input: IN): Iterator[InternalRow] = { try { rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) } catch { case e: BadRecordException => mode match { case PermissiveMode => Iterator(toResultRow(e.partialResult(), e.record)) case DropMalformedMode => Iterator.empty case FailFastMode => throw new SparkException("Malformed records are detected in record parsing. " + s"Parse Mode: ${FailFastMode.name}.", e.cause) } } } }
Example 5
Source File: ComplexDataSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class ComplexDataSuite extends SparkFunSuite {
  def utf8(str: String): UTF8String = UTF8String.fromString(str)

  test("inequality tests for MapData") {
    // test data
    val testMap1 = Map(utf8("key1") -> 1)
    val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val testMap3 = Map(utf8("key1") -> 1)
    val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericRow = genericRow.copy()
    assert(copiedGenericRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedGenericRow.getString(0) == "a")
  }

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val mutableRow = new SpecificInternalRow(Seq(StringType))
    mutableRow(0) = unsafeRow.getUTF8String(0)
    val copiedMutableRow = mutableRow.copy()
    assert(copiedMutableRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedMutableRow.getString(0) == "a")
  }

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericArray = genericArray.copy()
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied array data should not be changed externally.
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
  }

  test("copy on nested complex type") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
    val copied = arrayOfRow.copy()
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied data should not be changed externally.
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
  }
}
Example 6
Source File: NumberConverterSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }
}
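As a small illustration of what checkConv exercises, the sketch below calls NumberConverter.convert directly on the raw UTF-8 bytes of a number. The specific input is ours, and depending on the Spark version the object may only be accessible from code in the org.apache.spark.sql.catalyst.util package, so treat this as a sketch rather than a supported public API.

import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.unsafe.types.UTF8String

// "255" in base 10 is "FF" in base 16; convert works on the raw UTF-8 bytes.
val hex = NumberConverter.convert(UTF8String.fromString("255").getBytes, 10, 16)
println(hex) // FF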
Example 7
Source File: GenerateUnsafeProjectionSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenerateUnsafeProjectionSuite extends SparkFunSuite { test("Test unsafe projection string access pattern") { val dataType = (new StructType).add("a", StringType) val exprs = BoundReference(0, dataType, nullable = true) :: Nil val projection = GenerateUnsafeProjection.generate(exprs) val result = projection.apply(InternalRow(AlwaysNull)) assert(!result.isNullAt(0)) assert(result.getStruct(0, 1).isNullAt(0)) } } object AlwaysNull extends InternalRow { override def numFields: Int = 1 override def setNullAt(i: Int): Unit = {} override def copy(): InternalRow = this override def anyNull: Boolean = true override def isNullAt(ordinal: Int): Boolean = true override def update(i: Int, value: Any): Unit = notSupported override def getBoolean(ordinal: Int): Boolean = notSupported override def getByte(ordinal: Int): Byte = notSupported override def getShort(ordinal: Int): Short = notSupported override def getInt(ordinal: Int): Int = notSupported override def getLong(ordinal: Int): Long = notSupported override def getFloat(ordinal: Int): Float = notSupported override def getDouble(ordinal: Int): Double = notSupported override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported override def getUTF8String(ordinal: Int): UTF8String = notSupported override def getBinary(ordinal: Int): Array[Byte] = notSupported override def getInterval(ordinal: Int): CalendarInterval = notSupported override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported override def getArray(ordinal: Int): ArrayData = notSupported override def getMap(ordinal: Int): MapData = notSupported override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported private def notSupported: Nothing = throw new UnsupportedOperationException }
Example 8
Source File: MiscExpressionsSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("assert_true") { intercept[RuntimeException] { checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null) } intercept[RuntimeException] { checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null) } intercept[RuntimeException] { checkEvaluation(AssertTrue(Literal.create(null, NullType)), null) } intercept[RuntimeException] { checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null) } checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null) checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) } test("uuid") { def assertIncorrectEval(f: () => Unit): Unit = { intercept[Exception] { f() }.getMessage().contains("Incorrect evaluation") } checkEvaluation(Length(Uuid(Some(0))), 36) val r = new Random() val seed1 = Some(r.nextLong()) val uuid1 = evaluate(Uuid(seed1)).asInstanceOf[UTF8String] checkEvaluation(Uuid(seed1), uuid1.toString) val seed2 = Some(r.nextLong()) val uuid2 = evaluate(Uuid(seed2)).asInstanceOf[UTF8String] assertIncorrectEval(() => checkEvaluationWithoutCodegen(Uuid(seed1), uuid2)) assertIncorrectEval(() => checkEvaluationWithGeneratedMutableProjection(Uuid(seed1), uuid2)) assertIncorrectEval(() => checkEvalutionWithUnsafeProjection(Uuid(seed1), uuid2)) assertIncorrectEval(() => checkEvaluationWithOptimization(Uuid(seed1), uuid2)) } }
Example 9
Source File: StringUtils.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import java.util.regex.{Pattern, PatternSyntaxException}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter { name => regex.pattern.matcher(name).matches() }
      } catch {
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }
}
Example 10
Source File: CreateJacksonParser.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.json

import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text

import org.apache.spark.unsafe.types.UTF8String

private[sql] object CreateJacksonParser extends Serializable {
  def string(jsonFactory: JsonFactory, record: String): JsonParser = {
    jsonFactory.createParser(record)
  }

  def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
    val bb = record.getByteBuffer
    assert(bb.hasArray)

    val bain = new ByteArrayInputStream(
      bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())

    jsonFactory.createParser(new InputStreamReader(bain, "UTF-8"))
  }

  def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
    jsonFactory.createParser(record.getBytes, 0, record.getLength)
  }

  def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
    jsonFactory.createParser(record)
  }
}
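A rough usage sketch for the utf8String factory above. Because CreateJacksonParser is private[sql], the call is assumed to run from code compiled inside an org.apache.spark.sql sub-package; the JSON literal and the token loop are ours.

import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.sql.catalyst.json.CreateJacksonParser
import org.apache.spark.unsafe.types.UTF8String

val factory = new JsonFactory()
val record = UTF8String.fromString("""{"name":"spark"}""")

// Parse the record without first materializing it as a java.lang.String.
val parser = CreateJacksonParser.utf8String(factory, record)
try {
  while (parser.nextToken() != null) {
    // token-by-token processing would go here
  }
} finally {
  parser.close()
}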
Example 11
Source File: converters.scala From scylla-migrator with Apache License 2.0
package com.scylladb.migrator import java.nio.charset.StandardCharsets import java.util.UUID import com.datastax.spark.connector.types.{ CustomDriverConverter, NullableTypeConverter, PrimitiveColumnType, TypeConverter } import org.apache.spark.unsafe.types.UTF8String import scala.reflect.runtime.universe.TypeTag case object AnotherCustomUUIDConverter extends NullableTypeConverter[UUID] { def targetTypeTag = implicitly[TypeTag[UUID]] def convertPF = { case x: UUID => x case x: String => UUID.fromString(x) case x: UTF8String => UUID.fromString(new String(x.getBytes, StandardCharsets.UTF_8)) } } case object CustomTimeUUIDType extends PrimitiveColumnType[UUID] { def scalaTypeTag = implicitly[TypeTag[UUID]] def cqlTypeName = "timeuuid" def converterToCassandra = new TypeConverter.OptionToNullConverter(AnotherCustomUUIDConverter) } case object CustomUUIDType extends PrimitiveColumnType[UUID] { def scalaTypeTag = implicitly[TypeTag[UUID]] def cqlTypeName = "uuid" def converterToCassandra = new TypeConverter.OptionToNullConverter(AnotherCustomUUIDConverter) } object CustomUUIDConverter extends CustomDriverConverter { import org.apache.spark.sql.{ types => catalystTypes } import com.datastax.driver.core.DataType import com.datastax.spark.connector.types.ColumnType override val fromDriverRowExtension: PartialFunction[DataType, ColumnType[_]] = { case dataType if dataType.getName == DataType.timeuuid().getName => CustomTimeUUIDType case dataType if dataType.getName == DataType.uuid().getName => CustomUUIDType } override val catalystDataType: PartialFunction[ColumnType[_], catalystTypes.DataType] = { case CustomTimeUUIDType => catalystTypes.StringType case CustomUUIDType => catalystTypes.StringType } }
Example 12
Source File: PrefixComparatorsSuite.scala From spark1.52 with Apache License 2.0
package org.apache.spark.util.collection.unsafe.sort

import com.google.common.primitives.UnsignedBytes
import org.scalatest.prop.PropertyChecks

import org.apache.spark.SparkFunSuite
import org.apache.spark.unsafe.types.UTF8String

class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {

  test("String prefix comparator") { // string prefix comparator
    def testPrefixComparison(s1: String, s2: String): Unit = {
      val utf8string1 = UTF8String.fromString(s1)
      val utf8string2 = UTF8String.fromString(s2)
      val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1)
      val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2)
      val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)

      val cmp = UnsignedBytes.lexicographicalComparator().compare(
        utf8string1.getBytes.take(8), utf8string2.getBytes.take(8))

      assert(
        (prefixComparisonResult == 0 && cmp == 0) ||
        (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) ||
        (prefixComparisonResult > 0 && s1.compareTo(s2) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
    forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
  }

  test("Binary prefix comparator") { // binary prefix comparator
    def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
      for (i <- 0 until x.length; if i < y.length) {
        val res = x(i).compare(y(i))
        if (res != 0) return res
      }
      x.length - y.length
    }

    def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = {
      val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x)
      val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y)
      val prefixComparisonResult = PrefixComparators.BINARY.compare(s1Prefix, s2Prefix)
      assert(
        (prefixComparisonResult == 0) ||
        (prefixComparisonResult < 0 && compareBinary(x, y) < 0) ||
        (prefixComparisonResult > 0 && compareBinary(x, y) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
    forAll { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
  }

  // The double prefix comparator handles NaNs properly
  test("double prefix comparator handles NaNs properly") {
    val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
    val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
    assert(nan1.isNaN)
    assert(nan2.isNaN)
    val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
    val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2)
    assert(nan1Prefix === nan2Prefix)
    val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
    assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
  }
}
Example 13
Source File: DDLTestSuite.scala From spark1.52 with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext)
  }
}

case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType =
    StructType(Seq( // StructType represents a table, StructField represents one of its fields
      StructField("intType", IntegerType, nullable = false,
        new MetadataBuilder().putString("comment", s"test comment $table").build()),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType", // StructType represents a table, StructField represents one of its fields
        StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil)
      )
    ))

  // Whether conversion is needed
  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    sqlContext.sparkContext.parallelize(from to to).map { e =>
      InternalRow(UTF8String.fromString(s"people$e"), e * 2)
    }.asInstanceOf[RDD[Row]]
  }
}

class DDLTestSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql(
      """
        |CREATE TEMPORARY TABLE ddlPeople
        |USING org.apache.spark.sql.sources.DDLScanSource
        |OPTIONS (
        | From '1',
        | To '10',
        | Table 'test1'
        |)
      """.stripMargin)
  }

  sqlTest(
    "describe ddlPeople",
    Seq(
      Row("intType", "int", "test comment test1"),
      Row("stringType", "string", ""),
      Row("dateType", "date", ""),
      Row("timestampType", "timestamp", ""),
      Row("doubleType", "double", ""),
      Row("bigintType", "bigint", ""),
      Row("tinyintType", "tinyint", ""),
      Row("decimalType", "decimal(10,0)", ""),
      Row("fixedDecimalType", "decimal(5,1)", ""),
      Row("binaryType", "binary", ""),
      Row("booleanType", "boolean", ""),
      Row("smallIntType", "smallint", ""),
      Row("floatType", "float", ""),
      Row("mapType", "map<string,string>", ""),
      Row("arrayType", "array<string>", ""),
      Row("structType", "struct<f1:string,f2:int>", "")
    ))

  // DescribeCommand should have correct physical plan output attributes
  test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
    val attributes = sql("describe ddlPeople")
      .queryExecution.executedPlan.output
    assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
    assert(attributes.map(_.dataType).toSet === Set(StringType))
  }
}
Example 14
Source File: ColumnarTestUtils.scala From spark1.52 with Apache License 2.0
package org.apache.spark.sql.columnar

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{DataType, Decimal, AtomicType}
import org.apache.spark.unsafe.types.UTF8String

// Columnar test utilities
object ColumnarTestUtils {
  def makeNullRow(length: Int): GenericMutableRow = {
    val row = new GenericMutableRow(length)
    (0 until length).foreach(row.setNullAt)
    row
  }

  // Generate a random value for the given column type
  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
    def randomBytes(length: Int) = {
      val bytes = new Array[Byte](length)
      Random.nextBytes(bytes)
      bytes
    }

    (columnType match {
      case BOOLEAN => Random.nextBoolean()
      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
      case INT => Random.nextInt()
      case DATE => Random.nextInt()
      case LONG => Random.nextLong()
      case TIMESTAMP => Random.nextLong()
      case FLOAT => Random.nextFloat()
      case DOUBLE => Random.nextDouble()
      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
      case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
      case _ =>
        // Using a random one-element map instead of an arbitrary object
        Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
    }).asInstanceOf[JvmType]
  }

  def makeRandomValues(
      head: ColumnType[_],
      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)

  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
    columnTypes.map(makeRandomValue(_))
  }

  // Make unique random values
  def makeUniqueRandomValues[JvmType](
      columnType: ColumnType[JvmType],
      count: Int): Seq[JvmType] = {
    Iterator.iterate(HashSet.empty[JvmType]) { set =>
      set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
    }.drop(count).next().toSeq
  }

  def makeRandomRow(
      head: ColumnType[_],
      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)

  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
    val row = new GenericMutableRow(columnTypes.length)
    makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
      row(index) = value
    }
    row
  }

  // Make unique values and single-value rows
  def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
      columnType: NativeColumnType[T],
      count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = {
    val values = makeUniqueRandomValues(columnType, count)
    val rows = values.map { value =>
      val row = new GenericMutableRow(1)
      row(0) = value
      row
    }
    (values, rows)
  }
}
Example 15
Source File: ExtraStrategiesSuite.scala From spark1.52 with Apache License 2.0
package test.org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String

// A fast operator
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    sparkContext.parallelize(Seq(row))
  }

  // Nil is an empty List
  override def children: Seq[SparkPlan] = Nil
}

// A test strategy
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      // Nil is an empty List; :: prepends an element, creating a new list
      FastOperator(attr.toAttribute :: Nil) :: Nil
    // Nil is an empty List; :: prepends an element, creating a new list
    case _ => Nil
  }
}

// Extra strategies suite
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") { // insert an extra strategy
    try {
      // Nil is an empty List; :: prepends an element, creating a new list
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      // Nil is an empty List; :: prepends an element, creating a new list
      sqlContext.experimental.extraStrategies = Nil
    }
  }
}
Example 16
Source File: SQLImplicits.scala From spark1.52 with Apache License 2.0
package org.apache.spark.sql

import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types.StructField
import org.apache.spark.unsafe.types.UTF8String

// The enclosing class declaration was elided in the original listing; it is restored here so the
// excerpt compiles. _sqlContext is supplied by the concrete implicits object in the Spark source.
abstract class SQLImplicits {
  protected def _sqlContext: SQLContext

  implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
    val dataType = StringType
    val rows = data.mapPartitions { iter =>
      val row = new SpecificMutableRow(dataType :: Nil)
      iter.map { v =>
        row.update(0, UTF8String.fromString(v))
        row: InternalRow
      }
    }
    DataFrameHolder(
      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
  }
}
Example 17
Source File: GeneratedProjectionSuite.scala From spark1.52 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class GeneratedProjectionSuite extends SparkFunSuite {

  // Generated projections on a wider table
  test("generated projections on wider table") {
    val N = 1000
    val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
    val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
    val wideRow2 = new GenericInternalRow(
      (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
    val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
    val joined = new JoinedRow(wideRow1, wideRow2)
    val joinedSchema = StructType(schema1 ++ schema2)
    val nested = new JoinedRow(InternalRow(joined, joined), joined)
    val nestedSchema = StructType(
      Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)

    // test generated UnsafeProjection
    val unsafeProj = UnsafeProjection.create(nestedSchema)
    val unsafe: UnsafeRow = unsafeProj(nested)
    (0 until N).foreach { i =>
      val s = UTF8String.fromString((i + 1).toString)
      assert(i + 1 === unsafe.getInt(i + 2))
      assert(s === unsafe.getUTF8String(i + 2 + N))
      assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
      assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
      assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
      assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
    }

    // test generated SafeProjection
    val safeProj = FromUnsafeProjection(nestedSchema)
    val result = safeProj(unsafe)
    // Can't compare GenericInternalRow with JoinedRow directly
    (0 until N).foreach { i =>
      val r = i + 1
      val s = UTF8String.fromString((i + 1).toString)
      assert(r === result.getInt(i + 2))
      assert(s === result.getUTF8String(i + 2 + N))
      assert(r === result.getStruct(0, N * 2).getInt(i))
      assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
      assert(r === result.getStruct(1, N * 2).getInt(i))
      assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
    }

    // test generated MutableProjection
    val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
      BoundReference(i, f.dataType, true)
    }
    val mutableProj = GenerateMutableProjection.generate(exprs)()
    val row1 = mutableProj(result)
    assert(result === row1)
    val row2 = mutableProj(result)
    assert(result === row2)
  }

  test("generated unsafe projection with array of binary") {
    val row = InternalRow(
      Array[Byte](1, 2),
      new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4))))
    val fields = (BinaryType :: ArrayType(BinaryType) :: Nil).toArray[DataType]

    val unsafeProj = UnsafeProjection.create(fields)
    val unsafeRow: UnsafeRow = unsafeProj(row)
    assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2)))
    assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2)))
    assert(unsafeRow.getArray(1).isNullAt(1))
    assert(unsafeRow.getArray(1).getBinary(1) === null)
    assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4)))

    val safeProj = FromUnsafeProjection(fields)
    val row2 = safeProj(unsafeRow)
    assert(row2 === row)
  }
}
Example 18
Source File: MapDataSuite.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions import scala.collection._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} import org.apache.spark.unsafe.types.UTF8String class MapDataSuite extends SparkFunSuite { test("inequality tests") { def u(str: String): UTF8String = UTF8String.fromString(str) // test data val testMap1 = Map(u("key1") -> 1) val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) val testMap3 = Map(u("key1") -> 1) val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) // ArrayBasedMapData val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) assert(testArrayMap1 !== testArrayMap3) assert(testArrayMap2 !== testArrayMap4) // UnsafeMapData val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) val row = new GenericInternalRow(1) def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { row.update(0, map) val unsafeRow = unsafeConverter.apply(row) unsafeRow.getMap(0).copy } assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) } }
Example 19
Source File: StringUtils.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import java.util.regex.{Pattern, PatternSyntaxException}

import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

  // replace the _ with .{1} exactly match 1 time of any character
  // replace the % with .*, match 0 or more times with any character
  def escapeLikeRegex(v: String): String = {
    if (!v.isEmpty) {
      "(?s)" + (' ' +: v.init).zip(v).flatMap {
        case (prev, '\\') => ""
        case ('\\', c) =>
          c match {
            case '_' => "_"
            case '%' => "%"
            case _ => Pattern.quote("\\" + c)
          }
        case (prev, c) =>
          c match {
            case '_' => "."
            case '%' => ".*"
            case _ => Pattern.quote(Character.toString(c))
          }
      }.mkString
    } else {
      v
    }
  }

  private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
  private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)

  def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
  def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)

  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter { name => regex.pattern.matcher(name).matches() }
      } catch {
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }
}
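A short sketch of how the three helpers above behave; the expected outputs in the comments follow from the code shown, but treat them as our reading rather than documented guarantees.

import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.unsafe.types.UTF8String

// LIKE pattern -> Java regex: '%' becomes ".*" and '_' becomes "."
println(StringUtils.escapeLikeRegex("a%b_c"))                      // (?s)\Qa\E.*\Qb\E.\Qc\E

// Case-insensitive boolean-literal checks
println(StringUtils.isTrueString(UTF8String.fromString("YES")))    // true
println(StringUtils.isFalseString(UTF8String.fromString("0")))     // true

// Simple '*' wildcard filtering over names (e.g. SHOW FUNCTIONS LIKE 'm*')
println(StringUtils.filterPattern(Seq("max", "min", "avg"), "m*")) // Seq(max, min)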
Example 20
Source File: similarityFunctions.scala From spark-stringmetric with MIT License
package com.github.mrpowers.spark.stringmetric.expressions import com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions import org.apache.commons.text.similarity.CosineDistance import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{ CodegenContext, ExprCode } import org.apache.spark.sql.types.{ DataType, IntegerType, StringType } trait UTF8StringFunctionsHelper { val stringFuncs: String = "com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions" } trait StringString2IntegerExpression extends ImplicitCastInputTypes with NullIntolerant with UTF8StringFunctionsHelper { self: BinaryExpression => override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType) protected override def nullSafeEval(left: Any, right: Any): Any = -1 } case class HammingDistance(left: Expression, right: Expression) extends BinaryExpression with StringString2IntegerExpression { override def prettyName: String = "hamming" override def nullSafeEval(leftVal: Any, righValt: Any): Any = { val leftStr = left.asInstanceOf[UTF8String] val rightStr = right.asInstanceOf[UTF8String] UTF8StringFunctions.hammingDistance(leftStr, rightStr) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (s1, s2) => s"$stringFuncs.hammingDistance($s1, $s2)") } }
Example 21
Source File: PgWireProtocolSuite.scala From spark-sql-server with Apache License 2.0
package org.apache.spark.sql.server.service.postgresql.protocol.v3 import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.sql.SQLException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String class PgWireProtocolSuite extends SparkFunSuite { val conf = new SQLConf() test("DataRow") { val v3Protocol = new PgWireProtocol(65536) val row = new GenericInternalRow(2) row.update(0, 8) row.update(1, UTF8String.fromString("abcdefghij")) val schema = StructType.fromDDL("a INT, b STRING") val rowConverters = PgRowConverters(conf, schema, Seq(true, false)) val data = v3Protocol.DataRow(row, rowConverters) val bytes = ByteBuffer.wrap(data) assert(bytes.get() === 'D'.toByte) assert(bytes.getInt === 28) assert(bytes.getShort === 2) assert(bytes.getInt === 4) assert(bytes.getInt === 8) assert(bytes.getInt === 10) assert(data.slice(19, 30) === "abcdefghij".getBytes(StandardCharsets.UTF_8)) } test("Fails when message buffer overflowed") { val v3Protocol = new PgWireProtocol(4) val row = new GenericInternalRow(1) row.update(0, UTF8String.fromString("abcdefghijk")) val schema = StructType.fromDDL("a STRING") val rowConverters = PgRowConverters(conf, schema, Seq(false)) val errMsg = intercept[SQLException] { v3Protocol.DataRow(row, rowConverters) }.getMessage assert(errMsg.contains( "Cannot generate a V3 protocol message because buffer is not enough for the message. " + "To avoid this exception, you might set higher value at " + "'spark.sql.server.messageBufferSizeInBytes'") ) } }
Example 22
Source File: TestUtils.scala From spark-sql-server with Apache License 2.0
package org.apache.spark.sql.server

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.UTF8String

object TestUtils {

  def toSparkIntervalString(pgIntervalString: String): String = {
    // See: org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit
    val sparkIntervalString = UTF8String.fromString(pgIntervalString
      .replace("years", "year")
      .replace("mons", "month")
      .replace("days", "day")
      .replace("hours", "hour")
      .replace("mins", "minute")
      .replace("secs", "second")
    )
    IntervalUtils.stringToInterval(sparkIntervalString).toString
  }
}
Example 23
Source File: IntegerEncodedStringColumnBuffer.scala From spark-vector with Apache License 2.0
package com.actian.spark_vector.colbuffer.string import com.actian.spark_vector.colbuffer._ import com.actian.spark_vector.colbuffer.util.StringConversion import com.actian.spark_vector.vector.VectorDataType import org.apache.spark.unsafe.types.UTF8String import java.nio.ByteBuffer private[colbuffer] abstract class IntegerEncodedStringColumnBuffer(p: ColumnBufferBuildParams) extends ColumnBuffer[String, UTF8String](p.name, p.maxValueCount, IntSize, IntSize, p.nullable) { protected final val Whitespace = '\u0020' override def put(source: String, buffer: ByteBuffer): Unit = if (source.isEmpty()) { buffer.putInt(Whitespace) } else { buffer.putInt(encode(source)) } protected def encode(value: String): Int override def get(buffer: ByteBuffer): UTF8String = UTF8String.fromBytes(Character.toChars(buffer.getInt()).map(_.toByte)) } private class ConstantLengthSingleByteStringColumnBuffer(p: ColumnBufferBuildParams) extends IntegerEncodedStringColumnBuffer(p) { override protected def encode(value: String): Int = if (StringConversion.truncateToUTF8Bytes(value, 1).length == 0) { Whitespace } else { value.codePointAt(0) } } private class ConstantLengthSingleCharStringColumnBuffer(p: ColumnBufferBuildParams) extends IntegerEncodedStringColumnBuffer(p) { override protected def encode(value: String): Int = if (Character.isHighSurrogate(value.charAt(0))) { Whitespace } else { value.codePointAt(0) } } private[colbuffer] object IntegerEncodedStringColumnBuffer extends ColumnBufferBuilder { private val buildPartial: PartialFunction[ColumnBufferBuildParams, ColumnBufferBuildParams] = { case p if p.precision == 1 => p } override private[colbuffer] val build: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] = buildPartial andThenPartial { (ofDataType(VectorDataType.CharType) andThen { new ConstantLengthSingleByteStringColumnBuffer(_) }) orElse (ofDataType(VectorDataType.NcharType) andThen { new ConstantLengthSingleCharStringColumnBuffer(_) }) } }
Example 24
Source File: ByteEncodedStringColumnBuffer.scala From spark-vector with Apache License 2.0
package com.actian.spark_vector.colbuffer.string

import com.actian.spark_vector.colbuffer._
import com.actian.spark_vector.colbuffer.util.StringConversion
import com.actian.spark_vector.vector.VectorDataType

import org.apache.spark.unsafe.types.UTF8String

import java.nio.ByteBuffer

private[colbuffer] abstract class ByteEncodedStringColumnBuffer(p: ColumnBufferBuildParams)
  extends ColumnBuffer[String, UTF8String](p.name, p.maxValueCount, p.precision + 1, ByteSize, p.nullable) {

  override def put(source: String, buffer: ByteBuffer): Unit = {
    if (source.exists(c => (c >= '\uD800') && (c <= '\uDFFF') || (c == '\u0000')))
      throw new Exception(s"Illegal character in column '${p.name}' in string '$source'")
    buffer.put(encode(source))
    buffer.put(0.toByte)
  }

  protected def encode(value: String): Array[Byte]

  override def get(buffer: ByteBuffer): UTF8String = {
    // The body of get (and the concrete ByteLengthLimitedStringColumnBuffer /
    // CharLengthLimitedStringColumnBuffer classes referenced below) were collapsed in the
    // original listing; see the spark-vector source for the full implementation.
    ???
  }
}

private[colbuffer] object ByteEncodedStringColumnBuffer extends ColumnBufferBuilder {
  private val buildConstLenMultiPartial: PartialFunction[ColumnBufferBuildParams, ColumnBufferBuildParams] = {
    case p if p.precision > 1 => p
  }

  private val buildConstLenMulti: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] =
    buildConstLenMultiPartial andThenPartial {
      (ofDataType(VectorDataType.CharType) andThen { new ByteLengthLimitedStringColumnBuffer(_) }) orElse
      (ofDataType(VectorDataType.NcharType) andThen { new CharLengthLimitedStringColumnBuffer(_) })
    }

  private val buildVarLenPartial: PartialFunction[ColumnBufferBuildParams, ColumnBufferBuildParams] = {
    case p if p.precision > 0 => p
  }

  private val buildVarLen: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] =
    buildVarLenPartial andThenPartial {
      (ofDataType(VectorDataType.VarcharType) andThen { new ByteLengthLimitedStringColumnBuffer(_) }) orElse
      (ofDataType(VectorDataType.NvarcharType) andThen { new CharLengthLimitedStringColumnBuffer(_) })
    }

  override private[colbuffer] val build: PartialFunction[ColumnBufferBuildParams, ColumnBuffer[_, _]] =
    buildConstLenMulti orElse buildVarLen
}
Example 25
Source File: MeanSubstitute.scala From glow with Apache License 2.0
package io.projectglow.sql.expressions import org.apache.spark.sql.SQLUtils import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Average import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType} import org.apache.spark.unsafe.types.UTF8String import io.projectglow.sql.dsl._ import io.projectglow.sql.util.RewriteAfterResolution case class MeanSubstitute(array: Expression, missingValue: Expression) extends RewriteAfterResolution { override def children: Seq[Expression] = Seq(array, missingValue) def this(array: Expression) = { this(array, Literal(-1)) } private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType // A value is considered missing if it is NaN, null or equal to the missing value parameter def isMissing(arrayElement: Expression): Predicate = IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = { val sumName = Literal(UTF8String.fromString("sum"), StringType) val countName = Literal(UTF8String.fromString("count"), StringType) namedStruct(sumName, sumValue, countName, countValue) } // Update sum and count with array element if not missing def updateSumAndCountConditionally( stateStruct: Expression, arrayElement: Expression): Expression = { If( isMissing(arrayElement), // If value is missing, do not update sum and count stateStruct, // If value is not missing, add to sum and increment count createNamedStruct( stateStruct.getField("sum") + arrayElement, stateStruct.getField("count") + 1) ) } // Calculate mean for imputation def calculateMean(stateStruct: Expression): Expression = { If( stateStruct.getField("count") > 0, // If non-missing values were found, calculate the average stateStruct.getField("sum") / stateStruct.getField("count"), // If all values were missing, substitute with missing value missingValue ) } lazy val arrayMean: Expression = { // Sum and count of non-missing values array.aggregate( createNamedStruct(Literal(0d), Literal(0L)), updateSumAndCountConditionally, calculateMean ) } def substituteWithMean(arrayElement: Expression): Expression = { If(isMissing(arrayElement), arrayMean, arrayElement) } override def rewrite: Expression = { if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) { throw SQLUtils.newAnalysisException( s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.") } if (!missingValue.dataType.isInstanceOf[NumericType]) { throw SQLUtils.newAnalysisException( s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.") } // Replace missing values with the provided strategy array.arrayTransform(substituteWithMean(_)) } }
Example 30
Source File: PlinkRowToInternalRowConverter.scala From glow with Apache License 2.0 | 5 votes |
package io.projectglow.plink import org.apache.spark.sql.SQLUtils.structFieldsEqualExceptNullability import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types.{ArrayType, StructType} import org.apache.spark.unsafe.types.UTF8String import io.projectglow.common.{GlowLogging, VariantSchemas} import io.projectglow.sql.util.RowConverter class PlinkRowToInternalRowConverter(schema: StructType) extends GlowLogging { private val homAlt = new GenericArrayData(Array(1, 1)) private val missing = new GenericArrayData(Array(-1, -1)) private val het = new GenericArrayData(Array(0, 1)) private val homRef = new GenericArrayData(Array(0, 0)) private def twoBitsToCalls(twoBits: Int): GenericArrayData = { twoBits match { case 0 => homAlt // Homozygous for first (alternate) allele case 1 => missing // Missing genotype case 2 => het // Heterozygous case 3 => homRef // Homozygous for second (reference) allele } } private val converter = { val fns = schema.map { field => val fn: RowConverter.Updater[(Array[UTF8String], Array[Byte])] = field match { case f if f.name == VariantSchemas.genotypesFieldName => val gSchema = f.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType] val converter = makeGenotypeConverter(gSchema) (samplesAndBlock, r, i) => { val genotypes = new Array[Any](samplesAndBlock._1.length) var sampleIdx = 0 while (sampleIdx < genotypes.length) { val sample = samplesAndBlock._1(sampleIdx) // Get the relevant 2 bits for the sample from the block // The i-th sample's call bits are the (i%4)-th pair within the (i/4)-th block val twoBits = samplesAndBlock._2(sampleIdx / 4) >> (2 * (sampleIdx % 4)) & 3 genotypes(sampleIdx) = converter((sample, twoBits)) sampleIdx += 1 } r.update(i, new GenericArrayData(genotypes)) } case _ => // BED file only contains genotypes (_, _, _) => () } fn } new RowConverter[(Array[UTF8String], Array[Byte])](schema, fns.toArray) } private def makeGenotypeConverter(gSchema: StructType): RowConverter[(UTF8String, Int)] = { val functions = gSchema.map { field => val fn: RowConverter.Updater[(UTF8String, Int)] = field match { case f if structFieldsEqualExceptNullability(f, VariantSchemas.sampleIdField) => (sampleAndTwoBits, r, i) => { r.update(i, sampleAndTwoBits._1) } case f if structFieldsEqualExceptNullability(f, VariantSchemas.callsField) => (sampleAndTwoBits, r, i) => r.update(i, twoBitsToCalls(sampleAndTwoBits._2)) case f => logger.info( s"Genotype field $f cannot be derived from PLINK files. It will be null " + s"for each sample." ) (_, _, _) => () } fn } new RowConverter[(UTF8String, Int)](gSchema, functions.toArray) } def convertRow( bimRow: InternalRow, sampleIds: Array[UTF8String], gtBlock: Array[Byte]): InternalRow = { converter((sampleIds, gtBlock), bimRow) } }
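The genotype branch above pulls each sample's call out of the packed byte block with plain bit arithmetic. A minimal, self-contained sketch of just that step (hypothetical sample IDs and a hand-packed block byte, not glow's API; the object name is invented):

import org.apache.spark.unsafe.types.UTF8String

object PlinkTwoBitSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical sample IDs and a single packed genotype byte.
    val sampleIds = Array("s1", "s2", "s3", "s4").map(UTF8String.fromString)
    // Binary 11 10 01 00: sample 0 -> 0, sample 1 -> 1, sample 2 -> 2, sample 3 -> 3.
    val block = Array(Integer.parseInt("11100100", 2).toByte)

    sampleIds.zipWithIndex.foreach { case (id, i) =>
      // Same arithmetic as the converter: the (i % 4)-th pair within the (i / 4)-th byte.
      val twoBits = block(i / 4) >> (2 * (i % 4)) & 3
      println(s"$id -> $twoBits")
    }
  }
}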
Example 31
Source File: UTF8TextOutputFormatter.scala From glow with Apache License 2.0 | 5 votes |
package io.projectglow.transformers.pipe import java.io.InputStream import scala.collection.JavaConverters._ import org.apache.commons.io.IOUtils import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String class UTF8TextOutputFormatter() extends OutputFormatter { override def makeIterator(stream: InputStream): Iterator[Any] = { val schema = StructType(Seq(StructField("text", StringType))) val iter = IOUtils.lineIterator(stream, "UTF-8").asScala.map { s => new GenericInternalRow(Array(UTF8String.fromString(s)): Array[Any]) } Iterator(schema) ++ iter } } class UTF8TextOutputFormatterFactory extends OutputFormatterFactory { override def name: String = "text" override def makeOutputFormatter(options: Map[String, String]): OutputFormatter = { new UTF8TextOutputFormatter } }
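Each line read from the child process is wrapped as a UTF8String inside a GenericInternalRow. A small sketch of that round trip on its own, outside the pipe machinery (the object name is invented for illustration):

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.unsafe.types.UTF8String

object Utf8RowSketch {
  def main(args: Array[String]): Unit = {
    // Wrap a JVM String the same way the formatter does for each output line.
    val row = new GenericInternalRow(Array[Any](UTF8String.fromString("hello world")))

    // InternalRow stores strings as UTF8String; convert back to String only at the edge.
    val utf8: UTF8String = row.getUTF8String(0)
    println(utf8.numBytes()) // 11
    println(utf8.toString)   // hello world
  }
}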
Example 32
Source File: CouchbaseSink.scala From couchbase-spark-connector with Apache License 2.0 | 5 votes |
package com.couchbase.spark.sql.streaming import com.couchbase.spark.Logging import org.apache.spark.sql.{DataFrame, SaveMode} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.types.StringType import com.couchbase.spark.sql._ import com.couchbase.spark._ import com.couchbase.client.core.CouchbaseException import com.couchbase.client.java.document.JsonDocument import com.couchbase.client.java.document.json.JsonObject import scala.concurrent.duration._ class CouchbaseSink(options: Map[String, String]) extends Sink with Logging { override def addBatch(batchId: Long, data: DataFrame): Unit = { val bucketName = options.get("bucket").orNull val idFieldName = options.getOrElse("idField", DefaultSource.DEFAULT_DOCUMENT_ID_FIELD) val removeIdField = options.getOrElse("removeIdField", "true").toBoolean val timeout = options.get("timeout").map(v => Duration(v.toLong, MILLISECONDS)) val createDocument = options.get("expiry").map(_.toInt) .map(expiry => (id: String, content: JsonObject) => JsonDocument.create(id, expiry, content)) .getOrElse((id: String, content: JsonObject) => JsonDocument.create(id, content)) data.toJSON .queryExecution .toRdd .map(_.get(0, StringType).asInstanceOf[UTF8String].toString()) .map { rawJson => val encoded = JsonObject.fromJson(rawJson) val id = encoded.get(idFieldName) if (id == null) { throw new Exception(s"Could not find ID field $idFieldName in $encoded") } if (removeIdField) { encoded.removeKey(idFieldName) } createDocument(id.toString, encoded) } .saveToCouchbase(bucketName, StoreMode.UPSERT, timeout) } }
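addBatch relies on data.toJSON.queryExecution.toRdd yielding rows whose single StringType column is stored as a UTF8String. A hedged sketch of just that extraction step, using a local SparkSession and a toy DataFrame instead of a streaming sink (names are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object JsonRowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("utf8-json").getOrCreate()
    import spark.implicits._

    // toJSON produces a single string column; toRdd exposes it as InternalRow/UTF8String.
    val rawJson = Seq(("doc-1", 42)).toDF("id", "value")
      .toJSON
      .queryExecution
      .toRdd
      .map(_.get(0, StringType).asInstanceOf[UTF8String].toString)
      .collect()

    rawJson.foreach(println) // {"id":"doc-1","value":42}
    spark.stop()
  }
}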
Example 33
Source File: Utils.scala From hbase-connectors with Apache License 2.0 | 5 votes |
package org.apache.hadoop.hbase.spark.datasources import java.sql.{Date, Timestamp} import org.apache.hadoop.hbase.spark.AvroSerdes import org.apache.hadoop.hbase.util.Bytes import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.yetus.audience.InterfaceAudience; @InterfaceAudience.Private object Utils { def hbaseFieldToScalaType( f: Field, src: Array[Byte], offset: Int, length: Int): Any = { if (f.exeSchema.isDefined) { // If we have avro schema defined, use it to get record, and then convert them to catalyst data type val m = AvroSerdes.deserialize(src, f.exeSchema.get) val n = f.avroToCatalyst.map(_(m)) n.get } else { // Fall back to atomic type f.dt match { case BooleanType => src(offset) != 0 case ByteType => src(offset) case ShortType => Bytes.toShort(src, offset) case IntegerType => Bytes.toInt(src, offset) case LongType => Bytes.toLong(src, offset) case FloatType => Bytes.toFloat(src, offset) case DoubleType => Bytes.toDouble(src, offset) case DateType => new Date(Bytes.toLong(src, offset)) case TimestampType => new Timestamp(Bytes.toLong(src, offset)) case StringType => UTF8String.fromBytes(src, offset, length) case BinaryType => val newArray = new Array[Byte](length) System.arraycopy(src, offset, newArray, 0, length) newArray // TODO: SparkSqlSerializer.deserialize[Any](src) case _ => throw new Exception(s"unsupported data type ${f.dt}") } } } // convert input to data type def toBytes(input: Any, field: Field): Array[Byte] = { if (field.schema.isDefined) { // Here we assume the top level type is structType val record = field.catalystToAvro(input) AvroSerdes.serialize(record, field.schema.get) } else { field.dt match { case BooleanType => Bytes.toBytes(input.asInstanceOf[Boolean]) case ByteType => Array(input.asInstanceOf[Number].byteValue) case ShortType => Bytes.toBytes(input.asInstanceOf[Number].shortValue) case IntegerType => Bytes.toBytes(input.asInstanceOf[Number].intValue) case LongType => Bytes.toBytes(input.asInstanceOf[Number].longValue) case FloatType => Bytes.toBytes(input.asInstanceOf[Number].floatValue) case DoubleType => Bytes.toBytes(input.asInstanceOf[Number].doubleValue) case DateType | TimestampType => Bytes.toBytes(input.asInstanceOf[java.util.Date].getTime) case StringType => Bytes.toBytes(input.toString) case BinaryType => input.asInstanceOf[Array[Byte]] case _ => throw new Exception(s"unsupported data type ${field.dt}") } } } }
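The StringType branch decodes only a slice of the source array via UTF8String.fromBytes(src, offset, length), avoiding a copy of the surrounding bytes. A minimal sketch of that call with a plain byte array standing in for an HBase cell value (values chosen for illustration):

import org.apache.spark.unsafe.types.UTF8String

object FromBytesSketch {
  def main(args: Array[String]): Unit = {
    // Pretend this buffer is a cell value with two padding bytes on each side.
    val buffer = "xxHello, HBasexx".getBytes("UTF-8")

    // Decode only the [offset, offset + length) slice, as the StringType branch does.
    val s: UTF8String = UTF8String.fromBytes(buffer, 2, buffer.length - 4)
    println(s) // Hello, HBase
  }
}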
Example 34
Source File: GenomicIntervalStrategy.scala From bdg-sequila with Apache License 2.0 | 5 votes |
package org.biodatageeks.sequila.utvf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.{DataFrame, GenomicInterval, SparkSession, Strategy} import org.apache.spark.unsafe.types.UTF8String case class GIntervalRow(contigName: String, start: Int, end: Int) class GenomicIntervalStrategy( spark: SparkSession) extends Strategy with Serializable { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case GenomicInterval(contigName, start, end,output) => GenomicIntervalPlan(plan,spark,GIntervalRow(contigName,start,end),output) :: Nil case _ => Nil } } case class GenomicIntervalPlan(plan: LogicalPlan, spark: SparkSession,interval:GIntervalRow, output: Seq[Attribute]) extends SparkPlan with Serializable { def doExecute(): org.apache.spark.rdd.RDD[InternalRow] = { import spark.implicits._ lazy val genomicInterval = spark.createDataset(Seq(interval)) genomicInterval .rdd .map(r=>{ val proj = UnsafeProjection.create(schema) proj.apply(InternalRow.fromSeq(Seq(UTF8String.fromString(r.contigName),r.start,r.end))) } ) } def children: Seq[SparkPlan] = Nil }
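In doExecute above, the string field is handed to InternalRow as a UTF8String before the UnsafeProjection is applied; the projection expects the internal string representation, not java.lang.String. A small sketch of that step with a hypothetical three-column schema:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object IntervalProjectionSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical schema matching (contigName, start, end).
    val schema = StructType(Seq(
      StructField("contigName", StringType),
      StructField("start", IntegerType),
      StructField("end", IntegerType)))

    val proj = UnsafeProjection.create(schema)
    val row = proj(InternalRow.fromSeq(Seq(UTF8String.fromString("chr1"), 100, 200)))

    println(row.getUTF8String(0)) // chr1
    println(row.getInt(1))        // 100
  }
}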
Example 35
Source File: Serialize.scala From morpheus with Apache License 2.0 | 5 votes |
package org.opencypher.morpheus.impl.expressions import java.io.ByteArrayOutputStream import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, _} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.opencypher.morpheus.impl.expressions.EncodeLong.encodeLong import org.opencypher.morpheus.impl.expressions.Serialize._ import org.opencypher.okapi.impl.exception case class Serialize(children: Seq[Expression]) extends Expression { override def dataType: DataType = BinaryType override def nullable: Boolean = false // TODO: Only write length if more than one column is serialized override def eval(input: InternalRow): Any = { // TODO: Reuse from a pool instead of allocating a new one for each serialization val out = new ByteArrayOutputStream() children.foreach { child => child.dataType match { case BinaryType => write(child.eval(input).asInstanceOf[Array[Byte]], out) case StringType => write(child.eval(input).asInstanceOf[UTF8String], out) case IntegerType => write(child.eval(input).asInstanceOf[Int], out) case LongType => write(child.eval(input).asInstanceOf[Long], out) case other => throw exception.UnsupportedOperationException(s"Cannot serialize Spark data type $other.") } } out.toByteArray } override protected def doGenCode( ctx: CodegenContext, ev: ExprCode ): ExprCode = { ev.isNull = FalseLiteral val out = ctx.freshName("out") val serializeChildren = children.map { child => val childEval = child.genCode(ctx) s"""|${childEval.code} |if (!${childEval.isNull}) { | ${Serialize.getClass.getName.dropRight(1)}.write(${childEval.value}, $out); |}""".stripMargin }.mkString("\n") val baos = classOf[ByteArrayOutputStream].getName ev.copy( code = code"""|$baos $out = new $baos(); |$serializeChildren |byte[] ${ev.value} = $out.toByteArray();""".stripMargin) } } object Serialize { val supportedTypes: Set[DataType] = Set(BinaryType, StringType, IntegerType, LongType) @inline final def write(value: Array[Byte], out: ByteArrayOutputStream): Unit = { out.write(encodeLong(value.length)) out.write(value) } @inline final def write( value: Boolean, out: ByteArrayOutputStream ): Unit = write(if (value) 1.toLong else 0.toLong, out) @inline final def write(value: Byte, out: ByteArrayOutputStream): Unit = write(value.toLong, out) @inline final def write(value: Int, out: ByteArrayOutputStream): Unit = write(value.toLong, out) @inline final def write(value: Long, out: ByteArrayOutputStream): Unit = write(encodeLong(value), out) @inline final def write(value: UTF8String, out: ByteArrayOutputStream): Unit = write(value.getBytes, out) @inline final def write(value: String, out: ByteArrayOutputStream): Unit = write(value.getBytes, out) }
Example 36
Source File: PrefixComparatorsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util.collection.unsafe.sort import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { test("String prefix comparator") { def testPrefixComparison(s1: String, s2: String): Unit = { val utf8string1 = UTF8String.fromString(s1) val utf8string2 = UTF8String.fromString(s2) val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) val cmp = UnsignedBytes.lexicographicalComparator().compare( utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) assert( (prefixComparisonResult == 0 && cmp == 0) || (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) } // scalastyle:off val regressionTests = Table( ("s1", "s2"), ("abc", "世界"), ("你好", "世界"), ("你好123", "你好122") ) // scalastyle:on forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } test("Binary prefix comparator") { def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { for (i <- 0 until x.length; if i < y.length) { val res = x(i).compare(y(i)) if (res != 0) return res } x.length - y.length } def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = { val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x) val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y) val prefixComparisonResult = PrefixComparators.BINARY.compare(s1Prefix, s2Prefix) assert( (prefixComparisonResult == 0) || (prefixComparisonResult < 0 && compareBinary(x, y) < 0) || (prefixComparisonResult > 0 && compareBinary(x, y) > 0)) } // scalastyle:off val regressionTests = Table( ("s1", "s2"), ("abc", "世界"), ("你好", "世界"), ("你好123", "你好122") ) // scalastyle:on forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) } forAll { (s1: String, s2: String) => testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) } } test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } }
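The string test relies on PrefixComparators.StringPrefixComparator, which packs the leading bytes of a UTF8String into a single sortable long. A short usage sketch (values invented for illustration); note that strings sharing their first eight bytes get equal prefixes, so a prefix comparison of 0 does not imply the strings are equal:

import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators

object StringPrefixSketch {
  def main(args: Array[String]): Unit = {
    val a = PrefixComparators.StringPrefixComparator.computePrefix(UTF8String.fromString("apple"))
    val b = PrefixComparators.StringPrefixComparator.computePrefix(UTF8String.fromString("banana"))

    // A negative result means "apple" sorts before "banana" on its prefix alone.
    println(PrefixComparators.STRING.compare(a, b) < 0) // true

    // Equal first eight bytes give equal prefixes; a full comparison must break the tie.
    val c1 = PrefixComparators.StringPrefixComparator.computePrefix(UTF8String.fromString("prefix12-aaa"))
    val c2 = PrefixComparators.StringPrefixComparator.computePrefix(UTF8String.fromString("prefix12-zzz"))
    println(PrefixComparators.STRING.compare(c1, c2) == 0) // true
  }
}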
Example 37
Source File: DDLTestSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) } } case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, new MetadataBuilder().putString("comment", s"test comment $table").build()), StructField("stringType", StringType, nullable = false), StructField("dateType", DateType, nullable = false), StructField("timestampType", TimestampType, nullable = false), StructField("doubleType", DoubleType, nullable = false), StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), StructField("smallIntType", ShortType, nullable = false), StructField("floatType", FloatType, nullable = false), StructField("mapType", MapType(StringType, StringType)), StructField("arrayType", ArrayType(StringType)), StructField("structType", StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil ) ) )) override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sqlContext.sparkContext.parallelize(from to to).map { e => InternalRow(UTF8String.fromString(s"people$e"), e * 2) }.asInstanceOf[RDD[Row]] } } class DDLTestSuite extends DataSourceTest with SharedSQLContext { protected override lazy val sql = caseInsensitiveContext.sql _ override def beforeAll(): Unit = { super.beforeAll() sql( """ |CREATE TEMPORARY TABLE ddlPeople |USING org.apache.spark.sql.sources.DDLScanSource |OPTIONS ( | From '1', | To '10', | Table 'test1' |) """.stripMargin) } sqlTest( "describe ddlPeople", Seq( Row("intType", "int", "test comment test1"), Row("stringType", "string", ""), Row("dateType", "date", ""), Row("timestampType", "timestamp", ""), Row("doubleType", "double", ""), Row("bigintType", "bigint", ""), Row("tinyintType", "tinyint", ""), Row("decimalType", "decimal(10,0)", ""), Row("fixedDecimalType", "decimal(5,1)", ""), Row("binaryType", "binary", ""), Row("booleanType", "boolean", ""), Row("smallIntType", "smallint", ""), Row("floatType", "float", ""), Row("mapType", "map<string,string>", ""), Row("arrayType", "array<string>", ""), Row("structType", "struct<f1:string,f2:int>", "") )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { val attributes = sql("describe ddlPeople") .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) } }
Example 38
Source File: ExtraStrategiesSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package test.org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.{Row, Strategy, QueryTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.UTF8String case class FastOperator(output: Seq[Attribute]) extends SparkPlan { override protected def doExecute(): RDD[InternalRow] = { val str = Literal("so fast").value val row = new GenericInternalRow(Array[Any](str)) sparkContext.parallelize(Seq(row)) } override def children: Seq[SparkPlan] = Nil } object TestStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case Project(Seq(attr), _) if attr.name == "a" => FastOperator(attr.toAttribute :: Nil) :: Nil case _ => Nil } } class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("insert an extraStrategy") { try { sqlContext.experimental.extraStrategies = TestStrategy :: Nil val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( df.select("a"), Row("so fast")) checkAnswer( df.select("a", "b"), Row("so slow", 1)) } finally { sqlContext.experimental.extraStrategies = Nil } } }
Example 39
Source File: RowSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class RowSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) expected.setInt(0, 2147483647) expected.update(1, UTF8String.fromString("this is a string")) expected.setBoolean(2, false) expected.setNullAt(3) val actual1 = Row(2147483647, "this is a string", false, null) assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { val row = new SpecificMutableRow(Seq(IntegerType)) row(0) = null assert(row.isNullAt(0)) } test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() val serializer = new SparkSqlSerializer(sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] assert(de === row) } test("get values by field name on Row created via .toDF") { val row = Seq((1, Seq(1))).toDF("a", "b").first() assert(row.getAs[Int]("a") === 1) assert(row.getAs[Seq[Int]]("b") === Seq(1)) intercept[IllegalArgumentException]{ row.getAs[Int]("c") } } test("float NaN == NaN") { val r1 = Row(Float.NaN) val r2 = Row(Float.NaN) assert(r1 === r2) } test("double NaN == NaN") { val r1 = Row(Double.NaN) val r2 = Row(Double.NaN) assert(r1 === r2) } test("equals and hashCode") { val r1 = Row("Hello") val r2 = Row("Hello") assert(r1 === r2) assert(r1.hashCode() === r2.hashCode()) val r3 = Row("World") assert(r3.hashCode() != r1.hashCode()) } }
Example 40
Source File: ColumnarTestUtils.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { val row = new GenericMutableRow(length) (0 until length).foreach(row.setNullAt) row } def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) bytes } (columnType match { case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() case LONG => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) case STRUCT(_) => new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) case ARRAY(_) => new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) }).asInstanceOf[JvmType] } def makeRandomValues( head: ColumnType[_], tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } def makeUniqueRandomValues[JvmType]( columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next() }.drop(count).next().toSeq } def makeRandomRow( head: ColumnType[_], tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericMutableRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value } row } def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => val row = new GenericMutableRow(1) row(0) = value row } (values, rows) } }
Example 41
Source File: NumberConverterSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.NumberConverter.convert import org.apache.spark.unsafe.types.UTF8String class NumberConverterSuite extends SparkFunSuite { private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === UTF8String.fromString(expected)) } test("convert") { checkConv("3", 10, 2, "11") checkConv("-15", 10, -16, "-F") checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") checkConv("big", 36, 16, "3A48") checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") checkConv("11abc", 10, 16, "B") } }
Example 42
Source File: StringUtils.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import java.util.regex.Pattern import org.apache.spark.unsafe.types.UTF8String object StringUtils { // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character def escapeLikeRegex(v: String): String = { if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" case ('\\', c) => c match { case '_' => "_" case '%' => "%" case _ => Pattern.quote("\\" + c) } case (prev, c) => c match { case '_' => "." case '%' => ".*" case _ => Pattern.quote(Character.toString(c)) } }.mkString } else { v } } private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) }
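isTrueString and isFalseString above match case-insensitively because the incoming UTF8String is lower-cased before the set lookup. A minimal usage sketch against this (older) version of StringUtils, assuming the corresponding spark-catalyst jar is on the classpath:

import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.unsafe.types.UTF8String

object TrueStringSketch {
  def main(args: Array[String]): Unit = {
    println(StringUtils.isTrueString(UTF8String.fromString("YES")))    // true
    println(StringUtils.isTrueString(UTF8String.fromString("1")))      // true
    println(StringUtils.isFalseString(UTF8String.fromString("No")))    // true
    println(StringUtils.isFalseString(UTF8String.fromString("maybe"))) // false
  }
}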
Example 43
Source File: KinesisRecordToUnsafeRowConverter.scala From kinesis-sql with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.kinesis import com.amazonaws.services.kinesis.model.Record import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String private[kinesis] class KinesisRecordToUnsafeRowConverter { private val rowWriter = new UnsafeRowWriter(5) def toUnsafeRow(record: Record, streamName: String): UnsafeRow = { rowWriter.reset() rowWriter.write(0, record.getData.array()) rowWriter.write(1, UTF8String.fromString(streamName)) rowWriter.write(2, UTF8String.fromString(record.getPartitionKey)) rowWriter.write(3, UTF8String.fromString(record.getSequenceNumber)) rowWriter.write(4, DateTimeUtils.fromJavaTimestamp( new java.sql.Timestamp(record.getApproximateArrivalTimestamp.getTime))) rowWriter.getRow } }
Example 44
Source File: PDFDataSource.scala From mimir with Apache License 2.0 | 5 votes |
package mimir.exec.spark.datasource.pdf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types._ import org.apache.spark.sql.sources.v2.reader._ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import java.util.{Collections, List => JList, Optional} import org.apache.spark.unsafe.types.UTF8String import mimir.exec.spark.datasource.csv.CSVDataSourceReader class DefaultSource extends DataSourceV2 with ReadSupport { def createReader(options: DataSourceOptions) = { val path = options.get("path").get val pages = options.get("pages").orElse("all") val area = Option(options.get("area").orElse(null)) val hasGrid = options.get("gridLines").orElse("false").toBoolean val pdfExtractor = new PDFTableExtractor() val outPath = s"${path}.csv" pdfExtractor.defaultExtract(path, pages, area, Some(outPath), hasGrid) //println(s"------PDFDataSource----$path -> $outPath") //println({scala.io.Source.fromFile(outPath).mkString}) new CSVDataSourceReader(outPath, options.asMap().asScala.toMap) } }
Example 45
Source File: DataSourceTest.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String private[sql] abstract class DataSourceTest extends QueryTest { protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row], enableRegex: Boolean = false) { test(sqlString) { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString) { checkAnswer(spark.sql(sqlString), expectedAnswer) } } } } class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { SimpleDDLScan( parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext.sparkSession) } } case class SimpleDDLScan( from: Int, to: Int, table: String)(@transient val sparkSession: SparkSession) extends BaseRelation with TableScan { override def sqlContext: SQLContext = sparkSession.sqlContext override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), StructField("stringType", StringType, nullable = false), StructField("dateType", DateType, nullable = false), StructField("timestampType", TimestampType, nullable = false), StructField("doubleType", DoubleType, nullable = false), StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), StructField("smallIntType", ShortType, nullable = false), StructField("floatType", FloatType, nullable = false), StructField("mapType", MapType(StringType, StringType)), StructField("arrayType", ArrayType(StringType)), StructField("structType", StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil ) ) )) override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sparkSession.sparkContext.parallelize(from to to).map { e => InternalRow(UTF8String.fromString(s"people$e"), e * 2) }.asInstanceOf[RDD[Row]] } }
Example 46
Source File: XmlDataToCatalyst.scala From spark-xml with Apache License 2.0 | 5 votes |
package com.databricks.spark.xml import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import com.databricks.spark.xml.parsers.StaxXmlParser case class XmlDataToCatalyst( child: Expression, schema: DataType, options: XmlOptions) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { override lazy val dataType: DataType = schema @transient lazy val rowSchema: StructType = schema match { case st: StructType => st case ArrayType(st: StructType, _) => st } override def nullSafeEval(xml: Any): Any = xml match { case string: UTF8String => CatalystTypeConverters.convertToCatalyst( StaxXmlParser.parseColumn(string.toString, rowSchema, options)) case string: String => StaxXmlParser.parseColumn(string, rowSchema, options) case arr: GenericArrayData => CatalystTypeConverters.convertToCatalyst( arr.array.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options))) case arr: Array[_] => arr.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options)) case _ => null } override def inputTypes: Seq[DataType] = schema match { case _: StructType => Seq(StringType) case ArrayType(_: StructType, _) => Seq(ArrayType(StringType)) } }
Example 47
Source File: ColumnarTestUtils.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { def makeNullRow(length: Int): GenericInternalRow = { val row = new GenericInternalRow(length) (0 until length).foreach(row.setNullAt) row } def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) bytes } (columnType match { case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() case LONG => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) case STRUCT(_) => new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) case ARRAY(_) => new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } def makeRandomValues( head: ColumnType[_], tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } def makeUniqueRandomValues[JvmType]( columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next() }.drop(count).next().toSeq } def makeRandomRow( head: ColumnType[_], tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericInternalRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value } row } def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => val row = new GenericInternalRow(1) row(0) = value row } (values, rows) } }
Example 48
Source File: FailureSafeParser.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.datasources import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String class FailureSafeParser[IN]( rawParser: IN => Seq[InternalRow], mode: ParseMode, schema: StructType, columnNameOfCorruptRecord: String) { private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) private val resultRow = new GenericInternalRow(schema.length) private val nullResult = new GenericInternalRow(schema.length) // This function takes 2 parameters: an optional partial result, and the bad record. If the given // schema doesn't contain a field for corrupted record, we just return the partial result or a // row with all fields null. If the given schema contains a field for corrupted record, we will // set the bad record to this field, and set other fields according to the partial result or null. private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { if (corruptFieldIndex.isDefined) { (row, badRecord) => { var i = 0 while (i < actualSchema.length) { val from = actualSchema(i) resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull i += 1 } resultRow(corruptFieldIndex.get) = badRecord() resultRow } } else { (row, _) => row.getOrElse(nullResult) } } def parse(input: IN): Iterator[InternalRow] = { try { rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) } catch { case e: BadRecordException => mode match { case PermissiveMode => Iterator(toResultRow(e.partialResult(), e.record)) case DropMalformedMode => Iterator.empty case FailFastMode => throw new SparkException("Malformed records are detected in record parsing. " + s"Parse Mode: ${FailFastMode.name}.", e.cause) } } } }
Example 49
Source File: NumberConverterSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.NumberConverter.convert import org.apache.spark.unsafe.types.UTF8String class NumberConverterSuite extends SparkFunSuite { private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === UTF8String.fromString(expected)) } test("convert") { checkConv("3", 10, 2, "11") checkConv("-15", 10, -16, "-F") checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") checkConv("big", 36, 16, "3A48") checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") checkConv("11abc", 10, 16, "B") } }
Example 50
Source File: SortOrderExpressionsSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.util.TimeZone import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SortPrefix") { val b1 = Literal.create(false, BooleanType) val b2 = Literal.create(true, BooleanType) val i1 = Literal.create(20132983, IntegerType) val i2 = Literal.create(-20132983, IntegerType) val l1 = Literal.create(20132983, LongType) val l2 = Literal.create(-20132983, LongType) val millis = 1524954911000L; // Explicitly choose a time zone, since Date objects can create different values depending on // local time zone of the machine on which the test is running val oldDefaultTZ = TimeZone.getDefault val d1 = try { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) Literal.create(new java.sql.Date(millis), DateType) } finally { TimeZone.setDefault(oldDefaultTZ) } val t1 = Literal.create(new Timestamp(millis), TimestampType) val f1 = Literal.create(0.7788229f, FloatType) val f2 = Literal.create(-0.7788229f, FloatType) val db1 = Literal.create(0.7788229d, DoubleType) val db2 = Literal.create(-0.7788229d, DoubleType) val s1 = Literal.create("T", StringType) val s2 = Literal.create("This is longer than 8 characters", StringType) val bin1 = Literal.create(Array[Byte](12), BinaryType) val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]), BinaryType) val dec1 = Literal(Decimal(20132983L, 10, 2)) val dec2 = Literal(Decimal(20132983L, 19, 2)) val dec3 = Literal(Decimal(20132983L, 21, 2)) val list1 = Literal(List(1, 2), ArrayType(IntegerType)) val nullVal = Literal.create(null, IntegerType) checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L) checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L) checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L) checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L) checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L) checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L) // For some reason, the Literal.create code gives us the number of days since the epoch checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L) checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000) checkEvaluation(SortPrefix(SortOrder(f1, Ascending)), DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble)) checkEvaluation(SortPrefix(SortOrder(f2, Ascending)), DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble)) checkEvaluation(SortPrefix(SortOrder(db1, Ascending)), DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double])) checkEvaluation(SortPrefix(SortOrder(db2, Ascending)), DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double])) checkEvaluation(SortPrefix(SortOrder(s1, Ascending)), StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String])) checkEvaluation(SortPrefix(SortOrder(s2, Ascending)), StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String])) checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)), BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]])) checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)), BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]])) checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L) 
checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L) checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)), DoublePrefixComparator.computePrefix(201329.83d)) checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L) checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null) } }
Example 51
Source File: NumberConverter.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.unsafe.types.UTF8String object NumberConverter { def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { return null } if (n.length == 0) { return null } var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) // Copy the digits in the right side of the array val temp = new Array[Byte](64) var i = 1 while (i <= n.length - first) { temp(temp.length - i) = n(n.length - i) i += 1 } char2byte(fromBase, temp.length - n.length + first, temp) // Do the conversion by going through a 64 bit integer var v = encode(fromBase, temp.length - n.length + first, temp) if (negative && toBase > 0) { if (v < 0) { v = -1 } else { v = -v } } if (toBase < 0 && v < 0) { v = -v negative = true } decode(v, Math.abs(toBase), temp) // Find the first non-zero digit or the last digits if all are zero. val firstNonZeroPos = { val firstNonZero = temp.indexWhere( _ != 0) if (firstNonZero != -1) firstNonZero else temp.length - 1 } byte2char(Math.abs(toBase), firstNonZeroPos, temp) var resultStartPos = firstNonZeroPos if (negative && toBase < 0) { resultStartPos = firstNonZeroPos - 1 temp(resultStartPos) = '-' } UTF8String.fromBytes(java.util.Arrays.copyOfRange(temp, resultStartPos, temp.length)) } }
Example 52
Source File: StringUtils.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.sql.AnalysisException import org.apache.spark.unsafe.types.UTF8String object StringUtils { def filterPattern(names: Seq[String], pattern: String): Seq[String] = { val funcNames = scala.collection.mutable.SortedSet.empty[String] pattern.trim().split("\\|").foreach { subPattern => try { val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } } catch { case _: PatternSyntaxException => } } funcNames.toSeq } }
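filterPattern treats '*' as a wildcard, '|' as a separator between alternative patterns, and matches case-insensitively, returning the surviving names in sorted order. A short usage sketch with invented function names:

import org.apache.spark.sql.catalyst.util.StringUtils

object FilterPatternSketch {
  def main(args: Array[String]): Unit = {
    val names = Seq("to_date", "to_timestamp", "date_add", "current_date")

    // Prefix wildcard: keeps to_date and to_timestamp.
    StringUtils.filterPattern(names, "to_*").foreach(println)

    // Two alternatives: keeps current_date, date_add, to_date and to_timestamp.
    StringUtils.filterPattern(names, "*date*|*stamp").foreach(println)
  }
}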
Example 53
Source File: CreateJacksonParser.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.json import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} import java.nio.channels.Channels import java.nio.charset.Charset import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text import sun.nio.cs.StreamDecoder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.UTF8String private[sql] object CreateJacksonParser extends Serializable { def string(jsonFactory: JsonFactory, record: String): JsonParser = { jsonFactory.createParser(record) } def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = { val bb = record.getByteBuffer assert(bb.hasArray) val bain = new ByteArrayInputStream( bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { jsonFactory.createParser(record.getBytes, 0, record.getLength) } // Jackson parsers can be ranked according to their performance: // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser // but it doesn't allow to set encoding explicitly. Actual encoding is detected automatically // by checking leading bytes of the array. // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected // automatically by analyzing first bytes of the input stream. // 3. Reader based parser. This is the slowest parser used here but it allows to create // a reader with specific encoding. // The method creates a reader for an array with given encoding and sets size of internal // decoding buffer according to size of input array. private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = { val bais = new ByteArrayInputStream(in, 0, length) val byteChannel = Channels.newChannel(bais) val decodingBufferSize = Math.min(length, 8192) val decoder = Charset.forName(enc).newDecoder() StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize) } def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = { val sd = getStreamDecoder(enc, record.getBytes, record.getLength) jsonFactory.createParser(sd) } def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = { jsonFactory.createParser(is) } def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { jsonFactory.createParser(new InputStreamReader(is, enc)) } def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { val ba = row.getBinary(0) jsonFactory.createParser(ba, 0, ba.length) } def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { val binary = row.getBinary(0) val sd = getStreamDecoder(enc, binary, binary.length) jsonFactory.createParser(sd) } }
Example 54
Source File: misc.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @ExpressionDescription( usage = "_FUNC_() - Returns the current database.", examples = """ Examples: > SELECT _FUNC_(); default """) case class CurrentDatabase() extends LeafExpression with Unevaluable { override def dataType: DataType = StringType override def foldable: Boolean = true override def nullable: Boolean = false override def prettyName: String = "current_database" } // scalastyle:off line.size.limit @ExpressionDescription( usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""", examples = """ Examples: > SELECT _FUNC_(); 46707d92-02f4-4817-8116-a4c3b23e6266 """, note = "The function is non-deterministic.") // scalastyle:on line.size.limit case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful with ExpressionWithRandomSeed { def this() = this(None) override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false override def dataType: DataType = StringType @transient private[this] var randomGenerator: RandomUUIDGenerator = _ override protected def initializeInternal(partitionIndex: Int): Unit = randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex) override protected def evalInternal(input: InternalRow): Any = randomGenerator.getNextUUIDUTF8String() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val randomGen = ctx.freshName("randomGen") ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen, forceInline = true, useFreshName = false) ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } override def freshCopy(): Uuid = Uuid(randomSeed) }
Example 55
Source File: inputFileBlock.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @ExpressionDescription( usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.") case class InputFileName() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = StringType override def prettyName: String = "input_file_name" override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { InputFileBlockHolder.getInputFilePath } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") val typeDef = s"final ${CodeGenerator.javaType(dataType)}" ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", isNull = FalseLiteral) } } @ExpressionDescription( usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.") case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = LongType override def prettyName: String = "input_file_block_start" override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Long = { InputFileBlockHolder.getStartOffset } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") val typeDef = s"final ${CodeGenerator.javaType(dataType)}" ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @ExpressionDescription( usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.") case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = LongType override def prettyName: String = "input_file_block_length" override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Long = { InputFileBlockHolder.getLength } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") val typeDef = s"final ${CodeGenerator.javaType(dataType)}" ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } }
Example 56
Source File: InternalRow.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { case BooleanType => (input, ordinal) => input.getBoolean(ordinal) case ByteType => (input, ordinal) => input.getByte(ordinal) case ShortType => (input, ordinal) => input.getShort(ordinal) case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) case FloatType => (input, ordinal) => input.getFloat(ordinal) case DoubleType => (input, ordinal) => input.getDouble(ordinal) case StringType => (input, ordinal) => input.getUTF8String(ordinal) case BinaryType => (input, ordinal) => input.getBinary(ordinal) case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) case _: ArrayType => (input, ordinal) => input.getArray(ordinal) case _: MapType => (input, ordinal) => input.getMap(ordinal) case u: UserDefinedType[_] => getAccessor(u.sqlType) case _ => (input, ordinal) => input.get(ordinal, dataType) } }
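The accessor returned by getAccessor dispatches to the specialized getter for the given type, so a StringType column comes back as a UTF8String rather than a java.lang.String. A minimal sketch (object name invented for illustration):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object AccessorSketch {
  def main(args: Array[String]): Unit = {
    val row = InternalRow(UTF8String.fromString("hello"), 7)

    val getString = InternalRow.getAccessor(StringType)
    val getInt = InternalRow.getAccessor(IntegerType)

    // StringType dispatches to getUTF8String, IntegerType to getInt.
    println(getString(row, 0).asInstanceOf[UTF8String]) // hello
    println(getInt(row, 1))                             // 7
  }
}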
Example 57
Source File: DeltaRetentionSuiteBase.scala From delta with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.delta import java.io.File import org.apache.spark.sql.delta.DeltaOperations.Truncate import org.apache.spark.sql.delta.actions.Metadata import org.apache.spark.sql.delta.util.FileNames import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.unsafe.types.UTF8String trait DeltaRetentionSuiteBase extends QueryTest with SharedSparkSession { protected val testOp = Truncate() protected override def sparkConf: SparkConf = super.sparkConf // Disable the log cleanup because it runs asynchronously and causes test flakiness .set("spark.databricks.delta.properties.defaults.enableExpiredLogCleanup", "false") protected def intervalStringToMillis(str: String): Long = { DeltaConfigs.getMilliSeconds( IntervalUtils.safeStringToInterval(UTF8String.fromString(str))) } protected def getDeltaFiles(dir: File): Seq[File] = dir.listFiles().filter(_.getName.endsWith(".json")) protected def getCheckpointFiles(dir: File): Seq[File] = dir.listFiles().filter(f => FileNames.isCheckpointFile(new Path(f.getCanonicalPath))) protected def getLogFiles(dir: File): Seq[File] protected def startTxnWithManualLogCleanup(log: DeltaLog): OptimisticTransaction = { val txn = log.startTransaction() // This will pick up `spark.databricks.delta.properties.defaults.enableExpiredLogCleanup` to // disable log cleanup. txn.updateMetadata(Metadata()) txn } test("startTxnWithManualLogCleanup") { withTempDir { tempDir => val log = DeltaLog(spark, new Path(tempDir.getCanonicalPath)) startTxnWithManualLogCleanup(log).commit(Nil, testOp) assert(!log.enableExpiredLogCleanup) } } }
Example 58
Source File: DataConverter.scala From spark-cdm with MIT License | 5 votes |
package com.microsoft.cdm.utils import java.text.SimpleDateFormat import java.util.{Locale, TimeZone} import java.sql.Timestamp import org.apache.commons.lang.time.DateUtils import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class DataConverter() extends Serializable { val dateFormatter = new SimpleDateFormat(Constants.SINGLE_DATE_FORMAT) val timestampFormatter = TimestampFormatter(Constants.TIMESTAMP_FORMAT, TimeZone.getTimeZone("UTC")) val toSparkType: Map[CDMDataType.Value, DataType] = Map( CDMDataType.int64 -> LongType, CDMDataType.dateTime -> DateType, CDMDataType.string -> StringType, CDMDataType.double -> DoubleType, CDMDataType.decimal -> DecimalType(Constants.DECIMAL_PRECISION,0), CDMDataType.boolean -> BooleanType, CDMDataType.dateTimeOffset -> TimestampType ) def jsonToData(dt: DataType, value: String): Any = { return dt match { case LongType => value.toLong case DoubleType => value.toDouble case DecimalType() => Decimal(value) case BooleanType => value.toBoolean case DateType => dateFormatter.parse(value) case TimestampType => timestampFormatter.parse(value) case _ => UTF8String.fromString(value) } } def toCdmType(dt: DataType): CDMDataType.Value = { return dt match { case IntegerType => CDMDataType.int64 case LongType => CDMDataType.int64 case DateType => CDMDataType.dateTime case StringType => CDMDataType.string case DoubleType => CDMDataType.double case DecimalType() => CDMDataType.decimal case BooleanType => CDMDataType.boolean case TimestampType => CDMDataType.dateTimeOffset } } def dataToString(data: Any, dataType: DataType): String = { (dataType, data) match { case (_, null) => null case (DateType, _) => dateFormatter.format(data) case (TimestampType, v: Number) => timestampFormatter.format(data.asInstanceOf[Long]) case _ => data.toString } } }
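The jsonToData method above is the piece that touches UTF8String: any value that is not a recognised numeric, boolean, or date type is wrapped with UTF8String.fromString before it goes into a Catalyst row. A cut-down, standalone version of that fallback (with the CDM-specific types removed) might look like the sketch below; the object name and sample values are illustrative.
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object JsonToCatalystSketch {
  // Parse a raw string into the value Catalyst expects for the given data type.
  def toCatalyst(dt: DataType, value: String): Any = dt match {
    case LongType    => value.toLong
    case DoubleType  => value.toDouble
    case BooleanType => value.toBoolean
    case _           => UTF8String.fromString(value)  // strings must be UTF8String internally
  }

  def main(args: Array[String]): Unit = {
    println(toCatalyst(LongType, "42"))           // 42
    println(toCatalyst(StringType, "forty-two"))  // forty-two (as a UTF8String)
  }
}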
Example 59
Source File: CardinalityHashFunctionTest.scala From spark-alchemy with Apache License 2.0 | 5 votes |
package com.swoop.alchemy.spark.expressions.hll import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.scalatest.{Matchers, WordSpec} class CardinalityHashFunctionTest extends WordSpec with Matchers { "Cardinality hash functions" should { "account for nulls" in { val a = UTF8String.fromString("a") allDistinct(Seq( null, Array.empty[Byte], Array.apply(1.toByte) ), BinaryType) allDistinct(Seq( null, UTF8String.fromString(""), a ), StringType) allDistinct(Seq( null, ArrayData.toArrayData(Array.empty), ArrayData.toArrayData(Array(null)), ArrayData.toArrayData(Array(null, null)), ArrayData.toArrayData(Array(a, null)), ArrayData.toArrayData(Array(null, a)) ), ArrayType(StringType)) allDistinct(Seq( null, ArrayBasedMapData(Map.empty), ArrayBasedMapData(Map(null.asInstanceOf[String] -> null)) ), MapType(StringType, StringType)) allDistinct(Seq( null, InternalRow(null), InternalRow(a) ), new StructType().add("foo", StringType)) allDistinct(Seq( InternalRow(null, a), InternalRow(a, null) ), new StructType().add("foo", StringType).add("bar", StringType)) } } def allDistinct(values: Seq[Any], dataType: DataType): Unit = { val hashed = values.map(x => CardinalityXxHash64Function.hash(x, dataType, 0)) hashed.distinct.length should be(hashed.length) } }
Example 60
Source File: ColumnarTestUtils.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { def makeNullRow(length: Int): GenericInternalRow = { val row = new GenericInternalRow(length) (0 until length).foreach(row.setNullAt) row } def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) bytes } (columnType match { case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() case LONG => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) case STRUCT(_) => new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) case ARRAY(_) => new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } def makeRandomValues( head: ColumnType[_], tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } def makeUniqueRandomValues[JvmType]( columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next() }.drop(count).next().toSeq } def makeRandomRow( head: ColumnType[_], tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericInternalRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value } row } def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => val row = new GenericInternalRow(1) row(0) = value row } (values, rows) } }
Example 61
Source File: NumberConverterSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.NumberConverter.convert import org.apache.spark.unsafe.types.UTF8String class NumberConverterSuite extends SparkFunSuite { private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === UTF8String.fromString(expected)) } test("convert") { checkConv("3", 10, 2, "11") checkConv("-15", 10, -16, "-F") checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") checkConv("big", 36, 16, "3A48") checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") checkConv("11abc", 10, 16, "B") } }
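The same conversion can be driven directly, outside a test suite. The sketch below simply calls NumberConverter.convert on the bytes of a UTF8String; the values are illustrative.
import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.unsafe.types.UTF8String

object ConvSketch {
  def main(args: Array[String]): Unit = {
    // Convert the decimal string "255" to base 16; the result comes back as a UTF8String.
    val hex = NumberConverter.convert(UTF8String.fromString("255").getBytes, 10, 16)
    println(hex)  // expected: FF
  }
}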
Example 62
Source File: MapDataSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import scala.collection._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} import org.apache.spark.unsafe.types.UTF8String class MapDataSuite extends SparkFunSuite { test("inequality tests") { def u(str: String): UTF8String = UTF8String.fromString(str) // test data val testMap1 = Map(u("key1") -> 1) val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) val testMap3 = Map(u("key1") -> 1) val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) // ArrayBasedMapData val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) assert(testArrayMap1 !== testArrayMap3) assert(testArrayMap2 !== testArrayMap4) // UnsafeMapData val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) val row = new GenericInternalRow(1) def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { row.update(0, map) val unsafeRow = unsafeConverter.apply(row) unsafeRow.getMap(0).copy } assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) } }
Example 63
Source File: NumberConverter.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.unsafe.types.UTF8String object NumberConverter { def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { return null } if (n.length == 0) { return null } var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) // Copy the digits in the right side of the array val temp = new Array[Byte](64) var i = 1 while (i <= n.length - first) { temp(temp.length - i) = n(n.length - i) i += 1 } char2byte(fromBase, temp.length - n.length + first, temp) // Do the conversion by going through a 64 bit integer var v = encode(fromBase, temp.length - n.length + first, temp) if (negative && toBase > 0) { if (v < 0) { v = -1 } else { v = -v } } if (toBase < 0 && v < 0) { v = -v negative = true } decode(v, Math.abs(toBase), temp) // Find the first non-zero digit or the last digits if all are zero. val firstNonZeroPos = { val firstNonZero = temp.indexWhere( _ != 0) if (firstNonZero != -1) firstNonZero else temp.length - 1 } byte2char(Math.abs(toBase), firstNonZeroPos, temp) var resultStartPos = firstNonZeroPos if (negative && toBase < 0) { resultStartPos = firstNonZeroPos - 1 temp(resultStartPos) = '-' } UTF8String.fromBytes(java.util.Arrays.copyOfRange(temp, resultStartPos, temp.length)) } }
Example 64
Source File: StringUtils.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.unsafe.types.UTF8String object StringUtils { // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character def escapeLikeRegex(v: String): String = { if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" case ('\\', c) => c match { case '_' => "_" case '%' => "%" case _ => Pattern.quote("\\" + c) } case (prev, c) => c match { case '_' => "." case '%' => ".*" case _ => Pattern.quote(Character.toString(c)) } }.mkString } else { v } } private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) def filterPattern(names: Seq[String], pattern: String): Seq[String] = { val funcNames = scala.collection.mutable.SortedSet.empty[String] pattern.trim().split("\\|").foreach { subPattern => try { val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } } catch { case _: PatternSyntaxException => } } funcNames.toSeq } }
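A quick way to see these helpers in action is to call them directly with UTF8String inputs. The snippet below is a usage sketch, not part of the original file; the method names come from the object above, but the sample inputs are made up.
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.unsafe.types.UTF8String

object StringUtilsSketch {
  def main(args: Array[String]): Unit = {
    // Boolean-string checks are case-insensitive because the input is lower-cased first.
    println(StringUtils.isTrueString(UTF8String.fromString("YES")))  // expected: true
    println(StringUtils.isFalseString(UTF8String.fromString("0")))   // expected: true
    // filterPattern implements SHOW FUNCTIONS LIKE 'pattern'-style matching.
    println(StringUtils.filterPattern(Seq("substr", "substring", "sum"), "sub*"))
    // expected: a sequence containing substr and substring
  }
}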
Example 65
Source File: InputFileName.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileNameHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String @ExpressionDescription( usage = "_FUNC_() - Returns the name of the current file being read if available", extended = "> SELECT _FUNC_();\n ''") case class InputFileName() extends LeafExpression with Nondeterministic { override def nullable: Boolean = true override def dataType: DataType = StringType override def prettyName: String = "input_file_name" override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { InputFileNameHolder.getInputFileName() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false") } }
Example 66
Source File: HashingTF.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.feature import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.unsafe.hash.Murmur3_x86_32._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils private[spark] def murmur3Hash(term: Any): Int = { term match { case null => seed case b: Boolean => hashInt(if (b) 1 else 0, seed) case b: Byte => hashInt(b, seed) case s: Short => hashInt(s, seed) case i: Int => hashInt(i, seed) case l: Long => hashLong(l, seed) case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } } }
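The String branch above is the one that involves UTF8String: the term is first converted to UTF-8 bytes and then hashed with the unsafe Murmur3 primitive. A standalone sketch of just that branch, with an arbitrary seed, is shown below.
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.UTF8String

object TermHashSketch {
  def main(args: Array[String]): Unit = {
    val seed = 42  // illustrative; HashingTF uses its own fixed seed
    val utf8 = UTF8String.fromString("spark")
    // Hash the raw UTF-8 bytes in place, without copying them into a JVM String.
    val hash = Murmur3_x86_32.hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
    println(hash)
  }
}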
Example 67
Source File: StringUtils.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.unsafe.types.UTF8String object StringUtils { // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character def escapeLikeRegex(v: String): String = { if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" case ('\\', c) => c match { case '_' => "_" case '%' => "%" case _ => Pattern.quote("\\" + c) } case (prev, c) => c match { case '_' => "." case '%' => ".*" case _ => Pattern.quote(Character.toString(c)) } }.mkString } else { v } } private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) def filterPattern(names: Seq[String], pattern: String): Seq[String] = { val funcNames = scala.collection.mutable.SortedSet.empty[String] pattern.trim().split("\\|").foreach { subPattern => try { val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } } catch { case _: PatternSyntaxException => } } funcNames.toSeq } }
Example 68
Source File: SparkDateTime.scala From spark-datetime with Apache License 2.0 | 5 votes |
package org.apache.spark.sparklinedata.datetime import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @SQLUserDefinedType(udt = classOf[SparkDateTimeUDT]) case class SparkDateTime(millis : Long, tzId : String) class SparkDateTimeUDT extends UserDefinedType[SparkDateTime] { override def sqlType: DataType = StructType(Seq(StructField("millis", LongType), StructField("tz", StringType))) override def serialize(obj: SparkDateTime): InternalRow = { obj match { case dt: SparkDateTime => val row = new GenericMutableRow(2) row.setLong(0, dt.millis) row.update(1, CatalystTypeConverters.convertToCatalyst(dt.tzId)) row } } override def deserialize(datum: Any): SparkDateTime = { datum match { case row: InternalRow => require(row.numFields == 2, s"SparkDateTimeUDT.deserialize given row with length ${row.numFields} " + s"but requires length == 2") SparkDateTime(row.getLong(0), row.getString(1)) } } override def userClass: Class[SparkDateTime] = classOf[SparkDateTime] override def asNullable: SparkDateTimeUDT = this } @SQLUserDefinedType(udt = classOf[SparkPeriodUDT]) case class SparkPeriod(periodIsoStr : String) class SparkPeriodUDT extends UserDefinedType[SparkPeriod] { override def sqlType: DataType = StringType override def serialize(obj: SparkPeriod): Any = { obj match { case p: SparkPeriod => CatalystTypeConverters.convertToCatalyst(p.periodIsoStr) } } override def deserialize(datum: Any): SparkPeriod = { datum match { case s : UTF8String => SparkPeriod(s.toString()) } } override def userClass: Class[SparkPeriod] = classOf[SparkPeriod] override def asNullable: SparkPeriodUDT = this } @SQLUserDefinedType(udt = classOf[SparkIntervalUDT]) case class SparkInterval(intervalIsoStr : String) class SparkIntervalUDT extends UserDefinedType[SparkInterval] { override def sqlType: DataType = StringType override def serialize(obj: SparkInterval): Any = { obj match { case i: SparkInterval => CatalystTypeConverters.convertToCatalyst(i.intervalIsoStr) } } override def deserialize(datum: Any): SparkInterval = { datum match { case s : UTF8String => SparkInterval(s.toString()) } } override def userClass: Class[SparkInterval] = classOf[SparkInterval] override def asNullable: SparkIntervalUDT = this }
Example 69
Source File: JdbcUtil.scala From bahir with Apache License 2.0 | 5 votes |
package org.apache.bahir.sql.streaming.jdbc import java.sql.{Connection, PreparedStatement} import java.util.Locale import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String object JdbcUtil { def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse( throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) } // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for // `PreparedStatement`. The last argument `Int` means the index for the value to be set // in the SQL statement and also used for the value in `Row`. type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit def makeSetter( conn: Connection, dialect: JdbcDialect, dataType: DataType): JDBCValueSetter = dataType match { case IntegerType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setInt(pos + 1, row.getInt(pos)) case LongType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setLong(pos + 1, row.getLong(pos)) case DoubleType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setDouble(pos + 1, row.getDouble(pos)) case FloatType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setFloat(pos + 1, row.getFloat(pos)) case ShortType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setInt(pos + 1, row.getShort(pos)) case ByteType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setInt(pos + 1, row.getByte(pos)) case BooleanType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setBoolean(pos + 1, row.getBoolean(pos)) case StringType => (stmt: PreparedStatement, row: Row, pos: Int) => val strValue = row.get(pos) match { case str: UTF8String => str.toString case str: String => str } stmt.setString(pos + 1, strValue) case BinaryType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos)) case TimestampType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos)) case DateType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos)) case t: DecimalType => (stmt: PreparedStatement, row: Row, pos: Int) => stmt.setBigDecimal(pos + 1, row.getDecimal(pos)) case ArrayType(et, _) => // remove type length parameters from end of type name val typeName = getJdbcType(et, dialect).databaseTypeDefinition .toLowerCase(Locale.ROOT).split("\\(")(0) (stmt: PreparedStatement, row: Row, pos: Int) => val array = conn.createArrayOf( typeName, row.getSeq[AnyRef](pos).toArray) stmt.setArray(pos + 1, array) case _ => (_: PreparedStatement, _: Row, pos: Int) => throw new IllegalArgumentException( s"Can't translate non-null value for field $pos") } }
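The StringType setter above has to cope with both representations a Row can expose: UTF8String for internal rows and java.lang.String for external ones. The helper below isolates that normalization; the object and method names are illustrative, not from the original file.
import org.apache.spark.unsafe.types.UTF8String

object StringValueSketch {
  // Normalize either representation to a JVM String before handing it to JDBC.
  def toJdbcString(value: Any): String = value match {
    case s: UTF8String => s.toString
    case s: String     => s
    case other         => throw new IllegalArgumentException(s"Unexpected string value: $other")
  }

  def main(args: Array[String]): Unit = {
    println(toJdbcString(UTF8String.fromString("internal")))  // internal
    println(toJdbcString("external"))                         // external
  }
}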
Example 70
Source File: JavaConverter.scala From spark-dynamodb with Apache License 2.0 | 5 votes |
package com.audienceproject.spark.dynamodb.catalyst import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import scala.collection.JavaConverters._ object JavaConverter { def convertRowValue(row: InternalRow, index: Int, elementType: DataType): Any = { elementType match { case ArrayType(innerType, _) => convertArray(row.getArray(index), innerType) case MapType(keyType, valueType, _) => convertMap(row.getMap(index), keyType, valueType) case StructType(fields) => convertStruct(row.getStruct(index, fields.length), fields) case StringType => row.getString(index) case _ => row.get(index, elementType) } } def convertArray(array: ArrayData, elementType: DataType): Any = { elementType match { case ArrayType(innerType, _) => array.toSeq[ArrayData](elementType).map(convertArray(_, innerType)).asJava case MapType(keyType, valueType, _) => array.toSeq[MapData](elementType).map(convertMap(_, keyType, valueType)).asJava case structType: StructType => array.toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)).asJava case StringType => convertStringArray(array).asJava case _ => array.toSeq[Any](elementType).asJava } } def convertMap(map: MapData, keyType: DataType, valueType: DataType): util.Map[String, Any] = { if (keyType != StringType) throw new IllegalArgumentException( s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.") val keys = convertStringArray(map.keyArray()) val values = valueType match { case ArrayType(innerType, _) => map.valueArray().toSeq[ArrayData](valueType).map(convertArray(_, innerType)) case MapType(innerKeyType, innerValueType, _) => map.valueArray().toSeq[MapData](valueType).map(convertMap(_, innerKeyType, innerValueType)) case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)) case StringType => convertStringArray(map.valueArray()) case _ => map.valueArray().toSeq[Any](valueType) } val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i) Map(kvPairs: _*).asJava } def convertStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = { val kvPairs = for (i <- 0 until row.numFields) yield if (row.isNullAt(i)) fields(i).name -> null else fields(i).name -> convertRowValue(row, i, fields(i).dataType) Map(kvPairs: _*).asJava } def convertStringArray(array: ArrayData): Seq[String] = array.toSeq[UTF8String](StringType).map(_.toString) }
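convertStringArray above is the UTF8String-specific step: Catalyst stores string array elements as UTF8String, so they are converted back to JVM Strings before crossing into the Java SDK. A minimal standalone version of that round trip, with made-up values, might look like this.
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.JavaConverters._

object StringArraySketch {
  def main(args: Array[String]): Unit = {
    // Build an ArrayData of UTF8String elements, as Catalyst would for a string array column.
    val array = ArrayData.toArrayData(Array(UTF8String.fromString("a"), UTF8String.fromString("b")))
    // Convert back to JVM Strings and expose them as a java.util.List.
    val javaList = array.toSeq[UTF8String](StringType).map(_.toString).asJava
    println(javaList)  // expected: [a, b]
  }
}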
Example 71
Source File: TypeConversion.scala From spark-dynamodb with Apache License 2.0 | 5 votes |
package com.audienceproject.spark.dynamodb.datasource import com.amazonaws.services.dynamodbv2.document.{IncompatibleTypeException, Item} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import scala.collection.JavaConverters._ private[dynamodb] object TypeConversion { def apply(attrName: String, sparkType: DataType): Item => Any = sparkType match { case BooleanType => nullableGet(_.getBOOL)(attrName) case StringType => nullableGet(item => attrName => UTF8String.fromString(item.getString(attrName)))(attrName) case IntegerType => nullableGet(_.getInt)(attrName) case LongType => nullableGet(_.getLong)(attrName) case DoubleType => nullableGet(_.getDouble)(attrName) case FloatType => nullableGet(_.getFloat)(attrName) case BinaryType => nullableGet(_.getBinary)(attrName) case DecimalType() => nullableGet(_.getNumber)(attrName) case ArrayType(innerType, _) => nullableGet(_.getList)(attrName).andThen(extractArray(convertValue(innerType))) case MapType(keyType, valueType, _) => if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.") nullableGet(_.getRawMap)(attrName).andThen(extractMap(convertValue(valueType))) case StructType(fields) => val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) }) nullableGet(_.getRawMap)(attrName).andThen(extractStruct(nestedConversions)) case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.") } private val stringConverter = (value: Any) => UTF8String.fromString(value.asInstanceOf[String]) private def convertValue(sparkType: DataType): Any => Any = sparkType match { case IntegerType => nullableConvert(_.intValue()) case LongType => nullableConvert(_.longValue()) case DoubleType => nullableConvert(_.doubleValue()) case FloatType => nullableConvert(_.floatValue()) case DecimalType() => nullableConvert(identity) case ArrayType(innerType, _) => extractArray(convertValue(innerType)) case MapType(keyType, valueType, _) => if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.") extractMap(convertValue(valueType)) case StructType(fields) => val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) }) extractStruct(nestedConversions) case BooleanType => { case boolean: Boolean => boolean case _ => null } case StringType => { case string: String => UTF8String.fromString(string) case _ => null } case BinaryType => { case byteArray: Array[Byte] => byteArray case _ => null } case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.") } private def nullableGet(getter: Item => String => Any)(attrName: String): Item => Any = { case item if item.hasAttribute(attrName) => try getter(item)(attrName) catch { case _: NumberFormatException => null case _: IncompatibleTypeException => null } case _ => null } private def nullableConvert(converter: java.math.BigDecimal => Any): Any => Any = { case item: java.math.BigDecimal => converter(item) case _ => null } private def extractArray(converter: Any => Any): Any => Any = { case list: java.util.List[_] => new GenericArrayData(list.asScala.map(converter)) case set: java.util.Set[_] => new GenericArrayData(set.asScala.map(converter).toSeq) case _ => null } private def extractMap(converter: Any => Any): Any => Any = { case map: java.util.Map[_, _] => ArrayBasedMapData(map, stringConverter, converter) case _ => null } private def extractStruct(conversions: Seq[(String, Any => Any)]): Any => Any = { case map: java.util.Map[_, _] => InternalRow.fromSeq(conversions.map({ case (name, conv) => conv(map.get(name)) })) case _ => null } }
Example 72
Source File: RowSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class RowSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("create row") { val expected = new GenericInternalRow(4) expected.setInt(0, 2147483647) expected.update(1, UTF8String.fromString("this is a string")) expected.setBoolean(2, false) expected.setNullAt(3) val actual1 = Row(2147483647, "this is a string", false, null) assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { val row = new SpecificInternalRow(Seq(IntegerType)) row(0) = null assert(row.isNullAt(0)) } test("get values by field name on Row created via .toDF") { val row = Seq((1, Seq(1))).toDF("a", "b").first() assert(row.getAs[Int]("a") === 1) assert(row.getAs[Seq[Int]]("b") === Seq(1)) intercept[IllegalArgumentException]{ row.getAs[Int]("c") } } test("float NaN == NaN") { val r1 = Row(Float.NaN) val r2 = Row(Float.NaN) assert(r1 === r2) } test("double NaN == NaN") { val r1 = Row(Double.NaN) val r2 = Row(Double.NaN) assert(r1 === r2) } test("equals and hashCode") { val r1 = Row("Hello") val r2 = Row("Hello") assert(r1 === r2) assert(r1.hashCode() === r2.hashCode()) val r3 = Row("World") assert(r3.hashCode() != r1.hashCode()) } }
Example 73
Source File: ColumnarTestUtils.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { def makeNullRow(length: Int): GenericInternalRow = { val row = new GenericInternalRow(length) (0 until length).foreach(row.setNullAt) row } def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) bytes } (columnType match { case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() case LONG => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) case STRUCT(_) => new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) case ARRAY(_) => new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } def makeRandomValues( head: ColumnType[_], tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } def makeUniqueRandomValues[JvmType]( columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next() }.drop(count).next().toSeq } def makeRandomRow( head: ColumnType[_], tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericInternalRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value } row } def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => val row = new GenericInternalRow(1) row(0) = value row } (values, rows) } }
Example 74
Source File: NumberConverterSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.NumberConverter.convert import org.apache.spark.unsafe.types.UTF8String class NumberConverterSuite extends SparkFunSuite { private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === UTF8String.fromString(expected)) } test("convert") { checkConv("3", 10, 2, "11") checkConv("-15", 10, -16, "-F") checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") checkConv("big", 36, 16, "3A48") checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") checkConv("11abc", 10, 16, "B") } }
Example 75
Source File: MapDataSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import scala.collection._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} import org.apache.spark.unsafe.types.UTF8String class MapDataSuite extends SparkFunSuite { test("inequality tests") { def u(str: String): UTF8String = UTF8String.fromString(str) // test data val testMap1 = Map(u("key1") -> 1) val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) val testMap3 = Map(u("key1") -> 1) val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) // ArrayBasedMapData val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) assert(testArrayMap1 !== testArrayMap3) assert(testArrayMap2 !== testArrayMap4) // UnsafeMapData val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) val row = new GenericInternalRow(1) def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { row.update(0, map) val unsafeRow = unsafeConverter.apply(row) unsafeRow.getMap(0).copy } assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) } }
Example 76
Source File: DictionaryBasedEncoderSuite.scala From OAP with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.datasources.oap.io import org.apache.parquet.bytes.BytesInput import org.apache.parquet.column.page.DictionaryPage import org.apache.parquet.column.values.dictionary.PlainValuesDictionary.PlainBinaryDictionary import org.scalacheck.{Arbitrary, Gen, Properties} import org.scalacheck.Prop.forAll import org.scalatest.prop.Checkers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.oap.adapter.PropertiesAdapter import org.apache.spark.sql.execution.datasources.oap.filecache.StringFiberBuilder import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String class DictionaryBasedEncoderCheck extends Properties("DictionaryBasedEncoder") { private val rowCountInEachGroup = Gen.choose(1, 1024) private val rowCountInLastGroup = Gen.choose(1, 1024) private val groupCount = Gen.choose(1, 100) property("Encoding/Decoding String Type") = forAll { (values: Array[String]) => forAll(rowCountInEachGroup, rowCountInLastGroup, groupCount) { (rowCount, lastCount, groupCount) => if (values.nonEmpty) { // This is the 'PLAIN' FiberBuilder to validate the 'Encoding/Decoding' // Normally, the test case should be: // values => encoded bytes => decoded bytes => decoded values (Using ColumnValues class) // Validate if 'values' and 'decoded values' are identical. // But ColumnValues only support read value form DataFile. So, we have to use another way // to validate. val referenceFiberBuilder = StringFiberBuilder(rowCount, 0) val fiberBuilder = PlainBinaryDictionaryFiberBuilder(rowCount, 0, StringType) !(0 until groupCount).exists { group => // If lastCount > rowCount, assume lastCount = rowCount val count = if (group < groupCount - 1) { rowCount } else if (lastCount > rowCount) { rowCount } else { lastCount } (0 until count).foreach { row => fiberBuilder.append(InternalRow(UTF8String.fromString(values(row % values.length)))) referenceFiberBuilder .append(InternalRow(UTF8String.fromString(values(row % values.length)))) } val bytes = fiberBuilder.build().fiberData val dictionary = new PlainBinaryDictionary( new DictionaryPage( BytesInput.from(fiberBuilder.buildDictionary), fiberBuilder.getDictionarySize, org.apache.parquet.column.Encoding.PLAIN)) val fiberParser = PlainDictionaryFiberParser( new OapDataFileMetaV1(rowCountInEachGroup = rowCount), dictionary, StringType) val parsedBytes = fiberParser.parse(bytes, count) val referenceBytes = referenceFiberBuilder.build().fiberData referenceFiberBuilder.clear() referenceFiberBuilder.resetDictionary() fiberBuilder.clear() fiberBuilder.resetDictionary() assert(parsedBytes.length == referenceBytes.length) parsedBytes.zip(referenceBytes).exists(byte => byte._1 != byte._2) } } else { true } } } } class DictionaryBasedEncoderSuite extends SparkFunSuite with Checkers { test("Check Encoding/Decoding") { check(PropertiesAdapter.getProp(new DictionaryBasedEncoderCheck())) } }
Example 77
Source File: DeltaByteArrayEncoderSuite.scala From OAP with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.datasources.oap.io import org.scalacheck.{Arbitrary, Gen, Properties} import org.scalacheck.Prop.forAll import org.scalatest.prop.Checkers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.oap.adapter.PropertiesAdapter import org.apache.spark.sql.execution.datasources.oap.filecache.StringFiberBuilder import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String class DeltaByteArrayEncoderCheck extends Properties("DeltaByteArrayEncoder") { private val rowCountInEachGroup = Gen.choose(1, 1024) private val rowCountInLastGroup = Gen.choose(1, 1024) private val groupCount = Gen.choose(1, 100) property("Encoding/Decoding String Type") = forAll { (values: Array[String]) => forAll(rowCountInEachGroup, rowCountInLastGroup, groupCount) { (rowCount, lastCount, groupCount) => if (values.nonEmpty) { // This is the 'PLAIN' FiberBuilder to validate the 'Encoding/Decoding' // Normally, the test case should be: // values => encoded bytes => decoded bytes => decoded values (Using ColumnValues class) // Validate if 'values' and 'decoded values' are identical. // But ColumnValues only support read value form DataFile. So, we have to use another way // to validate. val referenceFiberBuilder = StringFiberBuilder(rowCount, 0) val fiberBuilder = DeltaByteArrayFiberBuilder(rowCount, 0, StringType) val fiberParser = DeltaByteArrayDataFiberParser( new OapDataFileMetaV1(rowCountInEachGroup = rowCount), StringType) !(0 until groupCount).exists { group => // If lastCount > rowCount, assume lastCount = rowCount val count = if (group < groupCount - 1) { rowCount } else if (lastCount > rowCount) { rowCount } else { lastCount } (0 until count).foreach { row => fiberBuilder.append(InternalRow(UTF8String.fromString(values(row % values.length)))) referenceFiberBuilder .append(InternalRow(UTF8String.fromString(values(row % values.length)))) } val bytes = fiberBuilder.build().fiberData val parsedBytes = fiberParser.parse(bytes, count) val referenceBytes = referenceFiberBuilder.build().fiberData referenceFiberBuilder.clear() fiberBuilder.clear() assert(parsedBytes.length == referenceBytes.length) parsedBytes.zip(referenceBytes).exists(byte => byte._1 != byte._2) } } else true } } } class DeltaByteArrayEncoderSuite extends SparkFunSuite with Checkers { test("Check Encoding/Decoding") { check(PropertiesAdapter.getProp(new DictionaryBasedEncoderCheck())) } }
Example 78
Source File: StatisticsTest.scala From OAP with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.datasources.oap.statistics import java.io.ByteArrayOutputStream import scala.collection.mutable.ArrayBuffer import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BaseOrdering import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.sql.execution.datasources.oap.filecache.FiberCache import org.apache.spark.sql.execution.datasources.oap.index.RangeInterval import org.apache.spark.sql.execution.datasources.oap.utils.{NonNullKeyReader, NonNullKeyWriter} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.UTF8String abstract class StatisticsTest extends SparkFunSuite with BeforeAndAfterEach { protected def rowGen(i: Int): InternalRow = InternalRow(i, UTF8String.fromString(s"test#$i")) protected lazy val schema: StructType = StructType(StructField("a", IntegerType) :: StructField("b", StringType) :: Nil) @transient protected lazy val nnkw: NonNullKeyWriter = new NonNullKeyWriter(schema) @transient protected lazy val nnkr: NonNullKeyReader = new NonNullKeyReader(schema) @transient protected lazy val ordering: BaseOrdering = GenerateOrdering.create(schema) @transient protected lazy val partialOrdering: BaseOrdering = GenerateOrdering.create(StructType(schema.dropRight(1))) protected var out: ByteArrayOutputStream = _ protected var intervalArray: ArrayBuffer[RangeInterval] = new ArrayBuffer[RangeInterval]() override def beforeEach(): Unit = { out = new ByteArrayOutputStream(8000) } override def afterEach(): Unit = { out.close() intervalArray.clear() } protected def generateInterval( start: InternalRow, end: InternalRow, startInclude: Boolean, endInclude: Boolean): Unit = { intervalArray.clear() intervalArray.append(new RangeInterval(start, end, startInclude, endInclude)) } protected def checkInternalRow(row1: InternalRow, row2: InternalRow): Unit = { val res = row1 == row2 // it works.. assert(res, s"row1: $row1 does not match $row2") } protected def wrapToFiberCache(out: ByteArrayOutputStream): FiberCache = { val bytes = out.toByteArray FiberCache(bytes) } }
Example 79
Source File: NonNullKeySuite.scala From OAP with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.datasources.oap.utils import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.oap.filecache.FiberCache import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ByteBufferOutputStream class NonNullKeySuite extends SparkFunSuite with Logging { private lazy val random = new Random(0) private lazy val values = { val booleans: Seq[Boolean] = Seq(true, false) val bytes: Seq[Byte] = Seq(Byte.MinValue, 0, 10, 30, Byte.MaxValue) val shorts: Seq[Short] = Seq(Short.MinValue, -100, 0, 10, 200, Short.MaxValue) val ints: Seq[Int] = Seq(Int.MinValue, -100, 0, 100, 12346, Int.MaxValue) val longs: Seq[Long] = Seq(Long.MinValue, -10000, 0, 20, Long.MaxValue) val floats: Seq[Float] = Seq(Float.MinValue, Float.MinPositiveValue, Float.MaxValue) val doubles: Seq[Double] = Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue) val strings: Seq[UTF8String] = Seq("", "test", "b plus tree", "BTreeRecordReaderWriter").map(UTF8String.fromString) val binaries: Seq[Array[Byte]] = (0 until 20 by 5).map{ size => val buf = new Array[Byte](size) random.nextBytes(buf) buf } val values = booleans ++ bytes ++ shorts ++ ints ++ longs ++ floats ++ doubles ++ strings ++ binaries ++ Nil random.shuffle(values) } private def toSparkDataType(any: Any): DataType = { any match { case _: Boolean => BooleanType case _: Short => ShortType case _: Byte => ByteType case _: Int => IntegerType case _: Long => LongType case _: Float => FloatType case _: Double => DoubleType case _: UTF8String => StringType case _: Array[Byte] => BinaryType } } test("Read/Write Based On Schema") { values.grouped(10).foreach { valueSeq => val schema = StructType(valueSeq.zipWithIndex.map { case (v, i) => StructField(s"col$i", toSparkDataType(v)) }) val nnkw = new NonNullKeyWriter(schema) val nnkr = new NonNullKeyReader(schema) val row = InternalRow.fromSeq(valueSeq) val buf = new ByteBufferOutputStream() nnkw.writeKey(buf, row) val answerRow = nnkr.readKey(FiberCache(buf.toByteArray), 0)._1 assert(row.equals(answerRow)) } } }
Example 80
Source File: OapDataReader.scala From OAP with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.datasources.oap.io import org.apache.hadoop.fs.FSDataInputStream import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{OapException, PartitionedFile} import org.apache.spark.sql.execution.datasources.oap.INDEX_STAT._ import org.apache.spark.sql.execution.datasources.oap.OapFileFormat import org.apache.spark.sql.execution.datasources.oap.io.OapDataFileProperties.DataFileVersion import org.apache.spark.sql.execution.datasources.oap.io.OapDataFileProperties.DataFileVersion.DataFileVersion import org.apache.spark.unsafe.types.UTF8String abstract class OapDataReader { def read(file: PartitionedFile): Iterator[InternalRow] // The two following fields have to be defined by certain versions of OapDataReader for use in // [[OapMetricsManager]] def rowsReadByIndex: Option[Long] def indexStat: INDEX_STAT } object OapDataReader extends Logging { def readVersion(is: FSDataInputStream, fileLen: Long): DataFileVersion = { val MAGIC_VERSION_LENGTH = 4 val metaEnd = fileLen - 4 // seek to the position of data file meta length is.seek(metaEnd) val metaLength = is.readInt() // read all bytes of data file meta val magicBuffer = new Array[Byte](MAGIC_VERSION_LENGTH) is.readFully(metaEnd - metaLength, magicBuffer) val magic = UTF8String.fromBytes(magicBuffer).toString magic match { case m if ! m.contains("OAP") => throw new OapException("Not a valid Oap Data File") case m if m == "OAP1" => DataFileVersion.OAP_DATAFILE_V1 case _ => throw new OapException("Not a supported Oap Data File version") } } def getDataFileClassFor(dataReaderClassFromDataSourceMeta: String, reader: OapDataReader): String = { dataReaderClassFromDataSourceMeta match { case c if c == OapFileFormat.PARQUET_DATA_FILE_CLASSNAME => c case c if c == OapFileFormat.ORC_DATA_FILE_CLASSNAME => c case c if c == OapFileFormat.OAP_DATA_FILE_CLASSNAME => reader match { case r: OapDataReaderV1 => OapFileFormat.OAP_DATA_FILE_V1_CLASSNAME case _ => throw new OapException(s"Undefined connection for $reader") } case _ => throw new OapException( s"Undefined data reader class name $dataReaderClassFromDataSourceMeta") } } }
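readVersion above uses UTF8String.fromBytes only to decode a fixed-width magic header into a comparable string. That decoding step on its own looks like the sketch below; the byte values are illustrative.
import org.apache.spark.unsafe.types.UTF8String

object MagicHeaderSketch {
  def main(args: Array[String]): Unit = {
    // Pretend these four bytes were read from the tail of a data file.
    val magicBuffer: Array[Byte] = "OAP1".getBytes("UTF-8")
    val magic = UTF8String.fromBytes(magicBuffer).toString
    println(magic == "OAP1")        // expected: true
    println(magic.contains("OAP"))  // expected: true
  }
}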
Example 81
Source File: NodeType.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.types import java.sql.Date import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.UTF8String class NodeType extends UserDefinedType[Node] { override val sqlType = StructType(Seq( StructField("path", ArrayType(StringType, containsNull = false), nullable = false), StructField("dataType", StringType, nullable = false), StructField("preRank", IntegerType, nullable = true), StructField("postRank", IntegerType, nullable = true), StructField("isLeaf", BooleanType, nullable = true), StructField("ordPath", ArrayType(LongType, containsNull=false), nullable = true) )) override def serialize(obj: Any): Any = obj match { case node: Node => InternalRow(new GenericArrayData(node.path.map { case null => null case p => UTF8String.fromString(p.toString) }), UTF8String.fromString(node.pathDataTypeJson), node.preRank, node.postRank, node.isLeaf, if (node.ordPath == null){ node.ordPath } else { new GenericArrayData(node.ordPath) }) case _ => throw new UnsupportedOperationException(s"Cannot serialize ${obj.getClass}") } // scalastyle:off cyclomatic.complexity override def deserialize(datum: Any): Node = datum match { case row: InternalRow => { val stringArray = row.getArray(0).toArray[UTF8String](StringType).map { case null => null case somethingElse => somethingElse.toString } val readDataTypeString: String = row.getString(1) val readDataType: DataType = DataType.fromJson(readDataTypeString) val path: Seq[Any] = readDataType match { case StringType => stringArray case LongType => stringArray.map(v => if (v != null) v.toLong else null) case IntegerType => stringArray.map(v => if (v != null) v.toInt else null) case DoubleType => stringArray.map(v => if (v != null) v.toDouble else null) case FloatType => stringArray.map(v => if (v != null) v.toFloat else null) case ByteType => stringArray.map(v => if (v != null) v.toByte else null) case BooleanType => stringArray.map(v => if (v != null) v.toBoolean else null) case TimestampType => stringArray.map(v => if (v != null) v.toLong else null) case dt: DataType => sys.error(s"Type $dt not supported for hierarchy path") } val preRank: Integer = if (row.isNullAt(2)) null else row.getInt(2) val postRank: Integer = if (row.isNullAt(3)) null else row.getInt(3) // scalastyle:off magic.number val isLeaf: java.lang.Boolean = if (row.isNullAt(4)) null else row.getBoolean(4) val ordPath: Seq[Long] = if (row.isNullAt(5)) null else row.getArray(5).toLongArray() // scalastyle:on magic.number Node( path, readDataTypeString, preRank, postRank, isLeaf, ordPath ) } case node: Node => node case _ => throw new UnsupportedOperationException(s"Cannot deserialize ${datum.getClass}") } // scalastyle:on override def userClass: java.lang.Class[Node] = classOf[Node] } case object NodeType extends NodeType
Example 82
Source File: ERPCurrencyConversionExpression.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.currency.erp import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.currency.CurrencyConversionException import org.apache.spark.sql.currency.erp.ERPConversionLoader.RConversionOptionsCurried import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import scala.util.control.NonFatal case class ERPCurrencyConversionExpression( conversionFunction: RConversionOptionsCurried, children: Seq[Expression]) extends Expression with ImplicitCastInputTypes with CodegenFallback { protected val CLIENT_INDEX = 0 protected val CONVERSION_TYPE_INDEX = 1 protected val AMOUNT_INDEX = 2 protected val FROM_INDEX = 3 protected val TO_INDEX = 4 protected val DATE_INDEX = 5 protected val NUM_ARGS = 6 protected val errorMessage = "Currency conversion library encountered an internal error" override def eval(input: InternalRow): Any = { val inputArguments = children.map(_.eval(input)) require(inputArguments.length == NUM_ARGS, "wrong number of arguments") // parse arguments val client = Option(inputArguments(CLIENT_INDEX).asInstanceOf[UTF8String]).map(_.toString) val conversionType = Option(inputArguments(CONVERSION_TYPE_INDEX).asInstanceOf[UTF8String]).map(_.toString) val amount = Option(inputArguments(AMOUNT_INDEX).asInstanceOf[Decimal].toJavaBigDecimal) val sourceCurrency = Option(inputArguments(FROM_INDEX).asInstanceOf[UTF8String]).map(_.toString) val targetCurrency = Option(inputArguments(TO_INDEX).asInstanceOf[UTF8String]).map(_.toString) val date = Option(inputArguments(DATE_INDEX).asInstanceOf[UTF8String]).map(_.toString) // perform conversion val conversion = conversionFunction(client, conversionType, sourceCurrency, targetCurrency, date) val resultTry = conversion(amount) // If 'resultTry' holds a 'Failure', we have to propagate it because potential failure // handling already took place. We just wrap it in case it is a cryptic error. resultTry.recover { case NonFatal(err) => throw new CurrencyConversionException(errorMessage, err) }.get.map(Decimal.apply).orNull } override def dataType: DataType = DecimalType.forType(DoubleType) override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, DecimalType, StringType, StringType, StringType) def inputNames: Seq[String] = Seq("client", "conversion_type", "amount", "source", "target", "date") def getChild(name: String): Option[Expression] = { inputNames.zip(children).find { case (n, _) => name == n }.map(_._2) } }
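The argument-parsing pattern above (cast the evaluated value to UTF8String, wrap it in Option, then call toString) is a common way to pull nullable string arguments out of evaluated expressions. Isolated, it reduces to the sketch below; the object name and sample values are made up.
import org.apache.spark.unsafe.types.UTF8String

object ArgParseSketch {
  // Expression evaluation yields UTF8String (or null) for string arguments.
  def asOptionalString(evaluated: Any): Option[String] =
    Option(evaluated.asInstanceOf[UTF8String]).map(_.toString)

  def main(args: Array[String]): Unit = {
    println(asOptionalString(UTF8String.fromString("EUR")))  // Some(EUR)
    println(asOptionalString(null))                          // None
  }
}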
Example 83
Source File: AnnotationParser.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.parser import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{AnnotationReference, Expression, Literal} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String protected def toTableMetadata(metadata: Map[String, Expression]): Metadata = { val res = new MetadataBuilder() metadata.foreach { case (k, v:Literal) => v.dataType match { case StringType => if (k.equals("?")) { sys.error("column metadata key can not be ?") } if (k.equals("*")) { sys.error("column metadata key can not be *") } res.putString(k, v.value.asInstanceOf[UTF8String].toString) case LongType => res.putLong(k, v.value.asInstanceOf[Long]) case DoubleType => res.putDouble(k, v.value.asInstanceOf[Double]) case NullType => res.putString(k, null) case a:ArrayType => res.putString(k, v.value.toString) } case (k, v:AnnotationReference) => sys.error("column metadata can not have a reference to another column metadata") } res.build() } }
Example 84
Source File: DataSourceTest.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String private[sql] abstract class DataSourceTest extends QueryTest { protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row], enableRegex: Boolean = false) { test(sqlString) { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString) { checkAnswer(spark.sql(sqlString), expectedAnswer) } } } } class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { SimpleDDLScan( parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext.sparkSession) } } case class SimpleDDLScan( from: Int, to: Int, table: String)(@transient val sparkSession: SparkSession) extends BaseRelation with TableScan { override def sqlContext: SQLContext = sparkSession.sqlContext override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), StructField("stringType", StringType, nullable = false), StructField("dateType", DateType, nullable = false), StructField("timestampType", TimestampType, nullable = false), StructField("doubleType", DoubleType, nullable = false), StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), StructField("smallIntType", ShortType, nullable = false), StructField("floatType", FloatType, nullable = false), StructField("mapType", MapType(StringType, StringType)), StructField("arrayType", ArrayType(StringType)), StructField("structType", StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil ) ) )) override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sparkSession.sparkContext.parallelize(from to to).map { e => InternalRow(UTF8String.fromString(s"people$e"), e * 2) }.asInstanceOf[RDD[Row]] } }
Example 85
Source File: RowSuite.scala From XSQL with Apache License 2.0 | 4 votes |
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)
    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))
    intercept[IllegalArgumentException] {
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
}
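The "create row" test above hinges on the split between the external Row API (plain JVM values) and GenericInternalRow (internal representations such as UTF8String). A condensed sketch of that contrast, assuming spark-catalyst and spark-unsafe on the classpath; the object name RowVsInternalRow is illustrative only:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical object name, for illustration only.
object RowVsInternalRow {
  def main(args: Array[String]): Unit = {
    // External Row holds a java.lang.String; GenericInternalRow holds a UTF8String.
    val external = Row(2147483647, "this is a string")
    val internal = new GenericInternalRow(2)
    internal.setInt(0, 2147483647)
    internal.update(1, UTF8String.fromString("this is a string"))
    assert(external.getString(1) == internal.getUTF8String(1).toString)
    println("external and internal rows agree")
  }
}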
Example 86
Source File: RowSuite.scala From drizzle-spark with Apache License 2.0 | 4 votes |
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)
    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))
    intercept[IllegalArgumentException] {
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
}
Example 87
Source File: CarbonCountStar.scala From carbondata with Apache License 2.0 | 4 votes |
package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.optimizer.CarbonFilters
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

import org.apache.carbondata.core.datastore.impl.FileFactory
import org.apache.carbondata.core.metadata.AbsoluteTableIdentifier
import org.apache.carbondata.core.metadata.schema.table.CarbonTable
import org.apache.carbondata.core.mutate.CarbonUpdateUtil
import org.apache.carbondata.core.statusmanager.StageInputCollector
import org.apache.carbondata.core.util.{CarbonProperties, ThreadLocalSessionInfo}
import org.apache.carbondata.hadoop.api.{CarbonInputFormat, CarbonTableInputFormat}
import org.apache.carbondata.hadoop.util.CarbonInputFormatUtil
import org.apache.carbondata.spark.load.DataLoadProcessBuilderOnSpark

case class CarbonCountStar(
    attributesRaw: Seq[Attribute],
    carbonTable: CarbonTable,
    sparkSession: SparkSession,
    outUnsafeRows: Boolean = true) extends LeafExecNode {

  override def doExecute(): RDD[InternalRow] = {
    ThreadLocalSessionInfo
      .setConfigurationToCurrentThread(sparkSession.sessionState.newHadoopConf())
    val absoluteTableIdentifier = carbonTable.getAbsoluteTableIdentifier
    val (job, tableInputFormat) = createCarbonInputFormat(absoluteTableIdentifier)
    CarbonInputFormat.setQuerySegment(job.getConfiguration, carbonTable)

    // get row count
    var rowCount = CarbonUpdateUtil.getRowCount(
      tableInputFormat.getBlockRowCount(
        job,
        carbonTable,
        CarbonFilters.getPartitions(
          Seq.empty,
          sparkSession,
          TableIdentifier(
            carbonTable.getTableName,
            Some(carbonTable.getDatabaseName))).map(_.asJava).orNull, false),
      carbonTable)

    if (CarbonProperties.isQueryStageInputEnabled) {
      // check for number of row for stage input
      val splits = StageInputCollector.createInputSplits(carbonTable, job.getConfiguration)
      if (!splits.isEmpty) {
        val df = DataLoadProcessBuilderOnSpark.createInputDataFrame(
          sparkSession, carbonTable, splits.asScala)
        rowCount += df.count()
      }
    }

    val valueRaw =
      attributesRaw.head.dataType match {
        case StringType => Seq(UTF8String.fromString(Long.box(rowCount).toString)).toArray
          .asInstanceOf[Array[Any]]
        case _ => Seq(Long.box(rowCount)).toArray.asInstanceOf[Array[Any]]
      }
    val value = new GenericInternalRow(valueRaw)
    val unsafeProjection = UnsafeProjection.create(output.map(_.dataType).toArray)
    val row = if (outUnsafeRows) unsafeProjection(value) else value
    sparkContext.parallelize(Seq(row))
  }

  override def output: Seq[Attribute] = {
    attributesRaw
  }

  private def createCarbonInputFormat(absoluteTableIdentifier: AbsoluteTableIdentifier
  ): (Job, CarbonTableInputFormat[Array[Object]]) = {
    val carbonInputFormat = new CarbonTableInputFormat[Array[Object]]()
    val jobConf: JobConf = new JobConf(FileFactory.getConfiguration)
    SparkHadoopUtil.get.addCredentials(jobConf)
    CarbonInputFormat.setTableInfo(jobConf, carbonTable.getTableInfo)
    val job = new Job(jobConf)
    FileInputFormat.addInputPath(job, new Path(absoluteTableIdentifier.getTablePath))
    CarbonInputFormat
      .setTransactionalTable(job.getConfiguration, carbonTable.getTableInfo.isTransactionalTable)
    CarbonInputFormatUtil.setIndexJobIfConfigured(job.getConfiguration)
    (job, carbonInputFormat)
  }
}
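The final step of doExecute packs the count into a one-column GenericInternalRow and, when unsafe rows are requested, runs it through an UnsafeProjection. A stripped-down sketch of just that step, assuming spark-catalyst and spark-unsafe on the classpath; the object name CountRowProjection and the row count are made up:

import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical object name; the count value is illustrative.
object CountRowProjection {
  def main(args: Array[String]): Unit = {
    val rowCount = 12345L
    // Surface a numeric count as a single string column: wrap it as UTF8String,
    // then project the generic row into an UnsafeRow.
    val value = new GenericInternalRow(Array[Any](UTF8String.fromString(rowCount.toString)))
    val projection = UnsafeProjection.create(Array[DataType](StringType))
    val unsafeRow = projection(value)
    println(unsafeRow.getUTF8String(0)) // 12345
  }
}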
Example 88
Source File: BasicCurrencyConversionExpression.scala From HANAVora-Extensions with Apache License 2.0 | 4 votes |
package org.apache.spark.sql.currency.basic

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class BasicCurrencyConversionExpression(
    conversion: BasicCurrencyConversion,
    children: Seq[Expression])
  extends Expression
  with ImplicitCastInputTypes
  with CodegenFallback {

  protected val AMOUNT_INDEX = 0
  protected val FROM_INDEX = 1
  protected val TO_INDEX = 2
  protected val DATE_INDEX = 3
  protected val NUM_ARGS = 4

  override def eval(input: InternalRow): Any = {
    val inputArguments = children.map(_.eval(input))

    require(inputArguments.length == NUM_ARGS, "wrong number of arguments")

    val sourceCurrency =
      Option(inputArguments(FROM_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val targetCurrency =
      Option(inputArguments(TO_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val amount =
      Option(inputArguments(AMOUNT_INDEX).asInstanceOf[Decimal].toJavaBigDecimal)
    val date = Option(inputArguments(DATE_INDEX).asInstanceOf[UTF8String]).map(_.toString)

    (amount, sourceCurrency, targetCurrency, date) match {
      case (Some(a), Some(s), Some(t), Some(d)) => nullSafeEval(a, s, t, d)
      case _ => null
    }
  }

  def nullSafeEval(amount: java.math.BigDecimal,
                   sourceCurrency: String,
                   targetCurrency: String,
                   date: String): Any = {
    conversion.convert(amount, sourceCurrency, targetCurrency, date)
      .get
      .map(Decimal.apply)
      .orNull
  }

  override def dataType: DataType = DecimalType.forType(DoubleType)

  override def nullable: Boolean = true

  // TODO(MD, CS): use DateType but support date string
  override def inputTypes: Seq[AbstractDataType] =
    Seq(DecimalType, StringType, StringType, StringType)
}
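The eval method above wraps each evaluated child in Option because Expression.eval returns null for SQL NULL, and string arguments arrive as UTF8String rather than java.lang.String. A minimal sketch of that null-safe unwrapping, assuming spark-unsafe on the classpath; the object name NullableUtf8Args and the sample currency code are made up for illustration:

import org.apache.spark.unsafe.types.UTF8String

// Hypothetical helper, for illustration only.
object NullableUtf8Args {
  // Convert a possibly-null UTF8String argument into Option[String].
  def asOptString(v: Any): Option[String] =
    Option(v.asInstanceOf[UTF8String]).map(_.toString)

  def main(args: Array[String]): Unit = {
    println(asOptString(UTF8String.fromString("EUR"))) // Some(EUR)
    println(asOptString(null))                         // None
  }
}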