org.apache.spark.sql.catalyst.util.ArrayData Scala Examples
The following examples show how to use org.apache.spark.sql.catalyst.util.ArrayData.
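For orientation before the project excerpts, here is a minimal sketch (not taken from any of the listed projects) of creating an ArrayData value and reading it back; ArrayData.toArrayData, numElements, getInt and toIntArray are all part of the Catalyst API used throughout the examples below.

import org.apache.spark.sql.catalyst.util.ArrayData

object ArrayDataSketch {
  def main(args: Array[String]): Unit = {
    // Wrap a JVM array in Catalyst's internal array representation
    val data: ArrayData = ArrayData.toArrayData(Array(1, 2, 3))

    // Read it back element by element or in bulk
    assert(data.numElements() == 3)
    assert(data.getInt(0) == 1)
    assert(data.toIntArray().sameElements(Array(1, 2, 3)))
  }
}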
Example 1
Source File: ArrayType.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.types

import scala.math.Ordering

import org.json4s.JsonDSL._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.util.ArrayData

// (companion object ArrayType omitted in this excerpt)
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {

  override def defaultSize: Int = 100 * elementType.defaultSize

  override def simpleString: String = s"array<${elementType.simpleString}>"

  override def catalogString: String = s"array<${elementType.catalogString}>"

  override def sql: String = s"ARRAY<${elementType.sql}>"

  override private[spark] def asNullable: ArrayType =
    ArrayType(elementType.asNullable, containsNull = true)

  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
    f(this) || elementType.existsRecursively(f)
  }

  @transient
  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
    private[this] val elementOrdering: Ordering[Any] = elementType match {
      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case other =>
        throw new IllegalArgumentException(s"Type $other does not support ordered operations")
    }

    def compare(x: ArrayData, y: ArrayData): Int = {
      val leftArray = x
      val rightArray = y
      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
      var i = 0
      while (i < minLength) {
        val isNullLeft = leftArray.isNullAt(i)
        val isNullRight = rightArray.isNullAt(i)
        if (isNullLeft && isNullRight) {
          // Do nothing.
        } else if (isNullLeft) {
          return -1
        } else if (isNullRight) {
          return 1
        } else {
          val comp = elementOrdering.compare(
            leftArray.get(i, elementType),
            rightArray.get(i, elementType))
          if (comp != 0) {
            return comp
          }
        }
        i += 1
      }
      if (leftArray.numElements() < rightArray.numElements()) {
        return -1
      } else if (leftArray.numElements() > rightArray.numElements()) {
        return 1
      } else {
        return 0
      }
    }
  }
}
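A possible usage sketch for the ordering defined above (not part of drizzle-spark; because interpretedOrdering is private[sql], the sketch assumes it is compiled inside the org.apache.spark.sql package tree):

package org.apache.spark.sql.types

import org.apache.spark.sql.catalyst.util.ArrayData

object ArrayOrderingSketch {
  def main(args: Array[String]): Unit = {
    val ordering = ArrayType(IntegerType, containsNull = true).interpretedOrdering
    val left = ArrayData.toArrayData(Array(1, 2, 3))
    val right = ArrayData.toArrayData(Array(1, 2, 4))
    // Elements are compared pairwise; the first difference decides the result,
    // and shorter arrays sort before longer ones when all shared elements are equal.
    assert(ordering.compare(left, right) < 0)
    assert(ordering.compare(left, left) == 0)
  }
}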
Example 2
Source File: CardinalityHashFunctionTest.scala From spark-alchemy with Apache License 2.0
package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.{Matchers, WordSpec}

class CardinalityHashFunctionTest extends WordSpec with Matchers {

  "Cardinality hash functions" should {
    "account for nulls" in {
      val a = UTF8String.fromString("a")
      allDistinct(Seq(
        null,
        Array.empty[Byte],
        Array.apply(1.toByte)
      ), BinaryType)
      allDistinct(Seq(
        null,
        UTF8String.fromString(""),
        a
      ), StringType)
      allDistinct(Seq(
        null,
        ArrayData.toArrayData(Array.empty),
        ArrayData.toArrayData(Array(null)),
        ArrayData.toArrayData(Array(null, null)),
        ArrayData.toArrayData(Array(a, null)),
        ArrayData.toArrayData(Array(null, a))
      ), ArrayType(StringType))
      allDistinct(Seq(
        null,
        ArrayBasedMapData(Map.empty),
        ArrayBasedMapData(Map(null.asInstanceOf[String] -> null))
      ), MapType(StringType, StringType))
      allDistinct(Seq(
        null,
        InternalRow(null),
        InternalRow(a)
      ), new StructType().add("foo", StringType))
      allDistinct(Seq(
        InternalRow(null, a),
        InternalRow(a, null)
      ), new StructType().add("foo", StringType).add("bar", StringType))
    }
  }

  def allDistinct(values: Seq[Any], dataType: DataType): Unit = {
    val hashed = values.map(x => CardinalityXxHash64Function.hash(x, dataType, 0))
    hashed.distinct.length should be(hashed.length)
  }
}
Example 3
Source File: ArrayType.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.types

import scala.math.Ordering

import org.json4s.JsonDSL._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.util.ArrayData

// (companion object ArrayType omitted in this excerpt)
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {

  override def defaultSize: Int = 1 * elementType.defaultSize

  override def simpleString: String = s"array<${elementType.simpleString}>"

  override def catalogString: String = s"array<${elementType.catalogString}>"

  override def sql: String = s"ARRAY<${elementType.sql}>"

  override private[spark] def asNullable: ArrayType =
    ArrayType(elementType.asNullable, containsNull = true)

  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
    f(this) || elementType.existsRecursively(f)
  }

  @transient
  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
    private[this] val elementOrdering: Ordering[Any] = elementType match {
      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case other =>
        throw new IllegalArgumentException(
          s"Type ${other.catalogString} does not support ordered operations")
    }

    def compare(x: ArrayData, y: ArrayData): Int = {
      val leftArray = x
      val rightArray = y
      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
      var i = 0
      while (i < minLength) {
        val isNullLeft = leftArray.isNullAt(i)
        val isNullRight = rightArray.isNullAt(i)
        if (isNullLeft && isNullRight) {
          // Do nothing.
        } else if (isNullLeft) {
          return -1
        } else if (isNullRight) {
          return 1
        } else {
          val comp = elementOrdering.compare(
            leftArray.get(i, elementType),
            rightArray.get(i, elementType))
          if (comp != 0) {
            return comp
          }
        }
        i += 1
      }
      if (leftArray.numElements() < rightArray.numElements()) {
        return -1
      } else if (leftArray.numElements() > rightArray.numElements()) {
        return 1
      } else {
        return 0
      }
    }
  }
}
Example 4
Source File: InternalRow.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// (abstract class InternalRow and the other companion-object members omitted in this excerpt)
object InternalRow {

  def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
    case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
    case ByteType => (input, ordinal) => input.getByte(ordinal)
    case ShortType => (input, ordinal) => input.getShort(ordinal)
    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
    case FloatType => (input, ordinal) => input.getFloat(ordinal)
    case DoubleType => (input, ordinal) => input.getDouble(ordinal)
    case StringType => (input, ordinal) => input.getUTF8String(ordinal)
    case BinaryType => (input, ordinal) => input.getBinary(ordinal)
    case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
    case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
    case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
    case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
    case _: MapType => (input, ordinal) => input.getMap(ordinal)
    case u: UserDefinedType[_] => getAccessor(u.sqlType)
    case _ => (input, ordinal) => input.get(ordinal, dataType)
  }
}
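A hypothetical call site for the accessor above (assuming the one-argument signature shown in this excerpt; newer Spark versions add a nullable parameter):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType}

object AccessorSketch {
  def main(args: Array[String]): Unit = {
    // Build the accessor once per column type, then reuse it for every row
    val getArrayColumn = InternalRow.getAccessor(ArrayType(IntegerType))
    val row = InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))
    val values = getArrayColumn(row, 0).asInstanceOf[ArrayData]
    assert(values.toIntArray().sameElements(Array(1, 2, 3)))
  }
}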
Example 5
Source File: JavaConverter.scala From spark-dynamodb with Apache License 2.0
package com.audienceproject.spark.dynamodb.catalyst

import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.JavaConverters._

object JavaConverter {

  def convertRowValue(row: InternalRow, index: Int, elementType: DataType): Any = {
    elementType match {
      case ArrayType(innerType, _) => convertArray(row.getArray(index), innerType)
      case MapType(keyType, valueType, _) => convertMap(row.getMap(index), keyType, valueType)
      case StructType(fields) => convertStruct(row.getStruct(index, fields.length), fields)
      case StringType => row.getString(index)
      case _ => row.get(index, elementType)
    }
  }

  def convertArray(array: ArrayData, elementType: DataType): Any = {
    elementType match {
      case ArrayType(innerType, _) =>
        array.toSeq[ArrayData](elementType).map(convertArray(_, innerType)).asJava
      case MapType(keyType, valueType, _) =>
        array.toSeq[MapData](elementType).map(convertMap(_, keyType, valueType)).asJava
      case structType: StructType =>
        array.toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)).asJava
      case StringType => convertStringArray(array).asJava
      case _ => array.toSeq[Any](elementType).asJava
    }
  }

  def convertMap(map: MapData, keyType: DataType, valueType: DataType): util.Map[String, Any] = {
    if (keyType != StringType) throw new IllegalArgumentException(
      s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
    val keys = convertStringArray(map.keyArray())
    val values = valueType match {
      case ArrayType(innerType, _) =>
        map.valueArray().toSeq[ArrayData](valueType).map(convertArray(_, innerType))
      case MapType(innerKeyType, innerValueType, _) =>
        map.valueArray().toSeq[MapData](valueType).map(convertMap(_, innerKeyType, innerValueType))
      case structType: StructType =>
        map.valueArray().toSeq[InternalRow](structType).map(convertStruct(_, structType.fields))
      case StringType => convertStringArray(map.valueArray())
      case _ => map.valueArray().toSeq[Any](valueType)
    }
    val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i)
    Map(kvPairs: _*).asJava
  }

  def convertStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = {
    val kvPairs = for (i <- 0 until row.numFields) yield
      if (row.isNullAt(i)) fields(i).name -> null
      else fields(i).name -> convertRowValue(row, i, fields(i).dataType)
    Map(kvPairs: _*).asJava
  }

  def convertStringArray(array: ArrayData): Seq[String] =
    array.toSeq[UTF8String](StringType).map(_.toString)
}
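A hypothetical call site for convertArray above, turning an internal array of UTF8String values into a java.util.List[String]:

import com.audienceproject.spark.dynamodb.catalyst.JavaConverter
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object JavaConverterSketch {
  def main(args: Array[String]): Unit = {
    // Internal string arrays hold UTF8String values, not java.lang.String
    val array = ArrayData.toArrayData(Array(UTF8String.fromString("a"), UTF8String.fromString("b")))
    val javaList = JavaConverter.convertArray(array, StringType)
    // javaList is a java.util.List[String] containing "a" and "b"
    println(javaList)
  }
}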
Example 6
Source File: GenerateUnsafeProjectionSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class GenerateUnsafeProjectionSuite extends SparkFunSuite {
  test("Test unsafe projection string access pattern") {
    val dataType = (new StructType).add("a", StringType)
    val exprs = BoundReference(0, dataType, nullable = true) :: Nil
    val projection = GenerateUnsafeProjection.generate(exprs)
    val result = projection.apply(InternalRow(AlwaysNull))
    assert(!result.isNullAt(0))
    assert(result.getStruct(0, 1).isNullAt(0))
  }
}

object AlwaysNull extends InternalRow {
  override def numFields: Int = 1
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this
  override def anyNull: Boolean = true
  override def isNullAt(ordinal: Int): Boolean = true
  override def update(i: Int, value: Any): Unit = notSupported
  override def getBoolean(ordinal: Int): Boolean = notSupported
  override def getByte(ordinal: Int): Byte = notSupported
  override def getShort(ordinal: Int): Short = notSupported
  override def getInt(ordinal: Int): Int = notSupported
  override def getLong(ordinal: Int): Long = notSupported
  override def getFloat(ordinal: Int): Float = notSupported
  override def getDouble(ordinal: Int): Double = notSupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
  override def getUTF8String(ordinal: Int): UTF8String = notSupported
  override def getBinary(ordinal: Int): Array[Byte] = notSupported
  override def getInterval(ordinal: Int): CalendarInterval = notSupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
  override def getArray(ordinal: Int): ArrayData = notSupported
  override def getMap(ordinal: Int): MapData = notSupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
  private def notSupported: Nothing = throw new UnsupportedOperationException
}
Example 7
Source File: JacksonGenerator.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils}

import scala.collection.Map

import com.fasterxml.jackson.core._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

private[sql] object JacksonGenerator {

  def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = {
    def valWriter: (DataType, Any) => Unit = {
      case (_, null) | (NullType, _) => gen.writeNull()
      case (StringType, v) => gen.writeString(v.toString)
      case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString)
      case (IntegerType, v: Int) => gen.writeNumber(v)
      case (ShortType, v: Short) => gen.writeNumber(v)
      case (FloatType, v: Float) => gen.writeNumber(v)
      case (DoubleType, v: Double) => gen.writeNumber(v)
      case (LongType, v: Long) => gen.writeNumber(v)
      case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal)
      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
      case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString)
      // For UDT values, they should be in the SQL type's corresponding value type.
      // We should not see values in the user-defined class at here.
      // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is
      // an ArrayData at here, instead of a Vector.
      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)

      case (ArrayType(ty, _), v: ArrayData) =>
        gen.writeStartArray()
        v.foreach(ty, (_, value) => valWriter(ty, value))
        gen.writeEndArray()

      case (MapType(kt, vt, _), v: MapData) =>
        gen.writeStartObject()
        v.foreach(kt, vt, { (k, v) =>
          gen.writeFieldName(k.toString)
          valWriter(vt, v)
        })
        gen.writeEndObject()

      case (StructType(ty), v: InternalRow) =>
        gen.writeStartObject()
        var i = 0
        while (i < ty.length) {
          val field = ty(i)
          val value = v.get(i, field.dataType)
          if (value != null) {
            gen.writeFieldName(field.name)
            valWriter(field.dataType, value)
          }
          i += 1
        }
        gen.writeEndObject()

      case (dt, v) =>
        sys.error(
          s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.")
    }

    valWriter(rowSchema, row)
  }
}
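The array branch above uses ArrayData.foreach, which passes each (index, element) pair to a callback; in isolation the call pattern looks like this (a sketch, not BigDatalog code):

import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.IntegerType

object ForeachSketch {
  def main(args: Array[String]): Unit = {
    val arr = ArrayData.toArrayData(Array(10, 20, 30))
    // The callback receives the element index and the element value
    arr.foreach(IntegerType, (i, value) => println(s"$i -> $value"))
  }
}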
Example 8
Source File: ExamplePointUDT.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.test

import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._

private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

  override def sqlType: DataType = ArrayType(DoubleType, false)

  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

  override def serialize(obj: Any): GenericArrayData = {
    obj match {
      case p: ExamplePoint =>
        val output = new Array[Any](2)
        output(0) = p.x
        output(1) = p.y
        new GenericArrayData(output)
    }
  }

  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
      case values: ArrayData =>
        new ExamplePoint(values.getDouble(0), values.getDouble(1))
    }
  }

  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]

  private[spark] override def asNullable: ExamplePointUDT = this
}
Example 9
Source File: MeanSubstitute.scala From glow with Apache License 2.0
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution

case class MeanSubstitute(array: Expression, missingValue: Expression)
  extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  def this(array: Expression) = {
    this(array, Literal(-1))
  }

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)
  }

  // Update sum and count with array element if not missing
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    If(
      isMissing(arrayElement),
      // If value is missing, do not update sum and count
      stateStruct,
      // If value is not missing, add to sum and increment count
      createNamedStruct(
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)
    )
  }

  // Calculate mean for imputation
  def calculateMean(stateStruct: Expression): Expression = {
    If(
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value
      missingValue
    )
  }

  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
    array.aggregate(
      createNamedStruct(Literal(0d), Literal(0L)),
      updateSumAndCountConditionally,
      calculateMean
    )
  }

  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)
  }

  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }
    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    // Replace missing values with the provided strategy
    array.arrayTransform(substituteWithMean(_))
  }
}
Example 10
Source File: LinearRegressionExpr.scala From glow with Apache License 2.0
package io.projectglow.sql.expressions

import breeze.linalg.DenseVector
import org.apache.spark.TaskContext
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, TernaryExpression}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._

object LinearRegressionExpr {
  private val matrixUDT = SQLUtils.newMatrixUDT()
  private val state = new ThreadLocal[CovariateQRContext]

  def doLinearRegression(genotypes: Any, phenotypes: Any, covariates: Any): InternalRow = {
    if (state.get() == null) {
      // Save the QR factorization of the covariate matrix since it's the same for every row
      state.set(CovariateQRContext.computeQR(matrixUDT.deserialize(covariates).toDense))
      TaskContext.get().addTaskCompletionListener[Unit](_ => state.remove())
    }

    LinearRegressionGwas.linearRegressionGwas(
      new DenseVector[Double](genotypes.asInstanceOf[ArrayData].toDoubleArray()),
      new DenseVector[Double](phenotypes.asInstanceOf[ArrayData].toDoubleArray()),
      state.get()
    )
  }
}

case class LinearRegressionExpr(
    genotypes: Expression,
    phenotypes: Expression,
    covariates: Expression)
  extends TernaryExpression
    with ImplicitCastInputTypes {

  private val matrixUDT = SQLUtils.newMatrixUDT()

  override def dataType: DataType =
    StructType(
      Seq(
        StructField("beta", DoubleType),
        StructField("standardError", DoubleType),
        StructField("pValue", DoubleType)))

  override def inputTypes: Seq[DataType] =
    Seq(ArrayType(DoubleType), ArrayType(DoubleType), matrixUDT)

  override def children: Seq[Expression] = Seq(genotypes, phenotypes, covariates)

  override protected def nullSafeEval(genotypes: Any, phenotypes: Any, covariates: Any): Any = {
    LinearRegressionExpr.doLinearRegression(genotypes, phenotypes, covariates)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(
      ctx,
      ev,
      (genotypes, phenotypes, covariates) => {
        s"""
           |${ev.value} = io.projectglow.sql.expressions.LinearRegressionExpr.doLinearRegression($genotypes, $phenotypes, $covariates);
         """.stripMargin
      }
    )
  }
}
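The expression above converts the incoming ArrayData into a plain Array[Double] before building breeze vectors; that conversion in isolation looks like this (a sketch, not glow code):

import breeze.linalg.DenseVector
import org.apache.spark.sql.catalyst.util.ArrayData

object ToDoubleArraySketch {
  def main(args: Array[String]): Unit = {
    val genotypes: ArrayData = ArrayData.toArrayData(Array(0.0, 1.0, 2.0))
    // toDoubleArray() copies the internal representation into a plain Array[Double],
    // which is what breeze's DenseVector constructor expects
    val vector = new DenseVector[Double](genotypes.toDoubleArray())
    assert(vector.length == 3)
  }
}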
Example 11
Source File: ShapeType.scala From Simba with Apache License 2.0
package org.apache.spark.sql.simba

import org.apache.spark.sql.types._
import org.apache.spark.sql.simba.spatial.Shape
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}

private[simba] class ShapeType extends UserDefinedType[Shape] {
  override def sqlType: DataType = ArrayType(ByteType, containsNull = false)

  override def serialize(s: Shape): Any = {
    new GenericArrayData(ShapeSerializer.serialize(s))
  }

  override def userClass: Class[Shape] = classOf[Shape]

  override def deserialize(datum: Any): Shape = {
    datum match {
      case values: ArrayData => ShapeSerializer.deserialize(values.toByteArray)
    }
  }
}

case object ShapeType extends ShapeType