org.apache.spark.sql.catalyst.util.ArrayData Scala Examples

Example 1
Source File: ArrayType.scala — from drizzle-spark, Apache License 2.0 (5 votes)
package org.apache.spark.sql.types

import scala.math.Ordering

import org.json4s.JsonDSL._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.util.ArrayData

  /** Estimated default size: the element type's default size multiplied by 100. */
  override def defaultSize: Int = {
    val perElement = elementType.defaultSize
    100 * perElement
  }

  /** Short readable form: the element type's `simpleString` wrapped as `array<...>`. */
  override def simpleString: String = {
    val elementName = elementType.simpleString
    s"array<$elementName>"
  }

  /** Catalog form: the element type's `catalogString` wrapped as `array<...>`. */
  override def catalogString: String = {
    val elementName = elementType.catalogString
    s"array<$elementName>"
  }

  /** SQL form: the element type's `sql` text wrapped as `ARRAY<...>`. */
  override def sql: String = {
    val elementSql = elementType.sql
    s"ARRAY<$elementSql>"
  }

  /** Copy of this array type with a nullable element type and `containsNull` set to true. */
  override private[spark] def asNullable: ArrayType = {
    val nullableElement = elementType.asNullable
    ArrayType(nullableElement, containsNull = true)
  }

  /**
   * Returns true when `f` holds for this type or, failing that, for the element type
   * (checked through the element type's own `existsRecursively`).
   */
  // NOTE(review): the closing brace of this method is not visible in this excerpt.
  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
    f(this) || elementType.existsRecursively(f)

  /**
   * Ordering over [[ArrayData]] values: arrays are compared element by element, and when
   * every compared position is equal the array with fewer elements sorts first.
   */
  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
    // Ordering for a single element, picked from the element type. Element types outside
    // the three cases below are rejected with IllegalArgumentException.
    private[this] val elementOrdering: Ordering[Any] = elementType match {
      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
      case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case other =>
        throw new IllegalArgumentException(s"Type $other does not support ordered operations")

    def compare(x: ArrayData, y: ArrayData): Int = {
      val leftArray = x
      val rightArray = y
      // Only positions present in both arrays are compared element-wise.
      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
      var i = 0
      while (i < minLength) {
        val isNullLeft = leftArray.isNullAt(i)
        val isNullRight = rightArray.isNullAt(i)
        // Two nulls are treated as equal; a null on one side only makes that side smaller.
        if (isNullLeft && isNullRight) {
          // Do nothing.
        } else if (isNullLeft) {
          return -1
        } else if (isNullRight) {
          return 1
        } else {
          // NOTE(review): the call that receives the two arguments below is not visible in
          // this excerpt -- presumably elementOrdering.compare; confirm against the full file.
          val comp =
              leftArray.get(i, elementType),
              rightArray.get(i, elementType))
          // The first non-equal element pair decides the result.
          if (comp != 0) {
            return comp
        i += 1
      // All compared positions were equal: decide by number of elements.
      if (leftArray.numElements() < rightArray.numElements()) {
        return -1
      } else if (leftArray.numElements() > rightArray.numElements()) {
        return 1
      } else {
        return 0
Example 2
Source File: CardinalityHashFunctionTest.scala — from spark-alchemy, Apache License 2.0 (5 votes)
package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.{Matchers, WordSpec}

// Spec for the cardinality hash function's handling of null values.
class CardinalityHashFunctionTest extends WordSpec with Matchers {

  "Cardinality hash functions" should {
    "account for nulls" in {
      val a = UTF8String.fromString("a")

      // NOTE(review): several lines of this test are missing from this excerpt. The fragments
      // below look like the tails of `allDistinct(Seq(...), <type>)` calls, one per data type
      // -- confirm against the full file.
      ), BinaryType)

      ), StringType)

        ArrayData.toArrayData(Array(null, null)),
        ArrayData.toArrayData(Array(a, null)),
        ArrayData.toArrayData(Array(null, a))
      ), ArrayType(StringType))

        ArrayBasedMapData(Map(null.asInstanceOf[String] -> null))
      ), MapType(StringType, StringType))

      ), new StructType().add("foo", StringType))

        InternalRow(null, a),
        InternalRow(a, null)
      ), new StructType().add("foo", StringType).add("bar", StringType))

  // Asserts that hashing with CardinalityXxHash64Function.hash (third argument 0) yields no
  // duplicate hashes: the distinct count must equal the total count.
  // NOTE(review): the start of the `hashed` expression (presumably `values.map(x`) is missing
  // from this excerpt.
  def allDistinct(values: Seq[Any], dataType: DataType): Unit = {
    val hashed = => CardinalityXxHash64Function.hash(x, dataType, 0))
    hashed.distinct.length should be(hashed.length)

Example 3
Source File: ArrayType.scala — from XSQL, Apache License 2.0 (5 votes)
package org.apache.spark.sql.types

import scala.math.Ordering

import org.json4s.JsonDSL._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.util.ArrayData

  /** Estimated default size: the element type's default size multiplied by 1. */
  override def defaultSize: Int = {
    val perElement = elementType.defaultSize
    1 * perElement
  }

  /** Short readable form: the element type's `simpleString` wrapped as `array<...>`. */
  override def simpleString: String = {
    val elementName = elementType.simpleString
    s"array<$elementName>"
  }

  /** Catalog form: the element type's `catalogString` wrapped as `array<...>`. */
  override def catalogString: String = {
    val elementName = elementType.catalogString
    s"array<$elementName>"
  }

  /** SQL form: the element type's `sql` text wrapped as `ARRAY<...>`. */
  override def sql: String = {
    val elementSql = elementType.sql
    s"ARRAY<$elementSql>"
  }

  /** Copy of this array type with a nullable element type and `containsNull` set to true. */
  override private[spark] def asNullable: ArrayType = {
    val nullableElement = elementType.asNullable
    ArrayType(nullableElement, containsNull = true)
  }

  /**
   * Returns true when `f` holds for this type or, failing that, for the element type
   * (checked through the element type's own `existsRecursively`).
   */
  // NOTE(review): the closing brace of this method is not visible in this excerpt.
  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
    f(this) || elementType.existsRecursively(f)

  /**
   * Ordering over [[ArrayData]] values: arrays are compared element by element, and when
   * every compared position is equal the array with fewer elements sorts first.
   */
  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
    // Ordering for a single element, picked from the element type. Element types outside
    // the three cases below are rejected with IllegalArgumentException, whose message uses
    // the offending type's catalogString.
    private[this] val elementOrdering: Ordering[Any] = elementType match {
      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
      case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case other =>
        throw new IllegalArgumentException(
          s"Type ${other.catalogString} does not support ordered operations")

    def compare(x: ArrayData, y: ArrayData): Int = {
      val leftArray = x
      val rightArray = y
      // Only positions present in both arrays are compared element-wise.
      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
      var i = 0
      while (i < minLength) {
        val isNullLeft = leftArray.isNullAt(i)
        val isNullRight = rightArray.isNullAt(i)
        // Two nulls are treated as equal; a null on one side only makes that side smaller.
        if (isNullLeft && isNullRight) {
          // Do nothing.
        } else if (isNullLeft) {
          return -1
        } else if (isNullRight) {
          return 1
        } else {
          // NOTE(review): the call that receives the two arguments below is not visible in
          // this excerpt -- presumably elementOrdering.compare; confirm against the full file.
          val comp =
              leftArray.get(i, elementType),
              rightArray.get(i, elementType))
          // The first non-equal element pair decides the result.
          if (comp != 0) {
            return comp
        i += 1
      // All compared positions were equal: decide by number of elements.
      if (leftArray.numElements() < rightArray.numElements()) {
        return -1
      } else if (leftArray.numElements() > rightArray.numElements()) {
        return 1
      } else {
        return 0
Example 4
Source File: InternalRow.scala — from XSQL, Apache License 2.0 (5 votes)
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

  /**
   * Returns a function that reads the value at a given ordinal from a [[SpecializedGetters]],
   * using the getter that matches `dataType`.
   *
   * @param dataType the type of the value to be read
   * @return a function of (input, ordinal) producing the value at that ordinal
   */
  // NOTE(review): the closing brace of this match is not visible in this excerpt.
  def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
    case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
    case ByteType => (input, ordinal) => input.getByte(ordinal)
    case ShortType => (input, ordinal) => input.getShort(ordinal)
    // Dates are read with the Int getter and timestamps with the Long getter.
    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
    case FloatType => (input, ordinal) => input.getFloat(ordinal)
    case DoubleType => (input, ordinal) => input.getDouble(ordinal)
    case StringType => (input, ordinal) => input.getUTF8String(ordinal)
    case BinaryType => (input, ordinal) => input.getBinary(ordinal)
    case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
    // Decimals and structs need extra arguments taken from the type itself.
    case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
    case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
    case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
    case _: MapType => (input, ordinal) => input.getMap(ordinal)
    // A user-defined type is read through the accessor of its underlying SQL type.
    case u: UserDefinedType[_] => getAccessor(u.sqlType)
    // Anything else goes through the generic getter, which takes the data type explicitly.
    case _ => (input, ordinal) => input.get(ordinal, dataType)
Example 5
Source File: JavaConverter.scala — from spark-dynamodb, Apache License 2.0 (5 votes)
package com.audienceproject.spark.dynamodb.catalyst

import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.JavaConverters._

// Converts Catalyst internal values (InternalRow, ArrayData, MapData) into Java collections,
// recursing through nested arrays, maps and structs.
// NOTE(review): closing braces and the body of convertStringArray are not visible in this
// excerpt.
object JavaConverter {

    /**
     * Reads the field at `index` from `row` according to `elementType`. Arrays, maps and
     * structs are converted recursively, strings are read with `getString`, and every other
     * type is read with the generic getter.
     */
    def convertRowValue(row: InternalRow, index: Int, elementType: DataType): Any = {
        elementType match {
            case ArrayType(innerType, _) => convertArray(row.getArray(index), innerType)
            case MapType(keyType, valueType, _) => convertMap(row.getMap(index), keyType, valueType)
            case StructType(fields) => convertStruct(row.getStruct(index, fields.length), fields)
            case StringType => row.getString(index)
            case _ => row.get(index, elementType)

    /**
     * Converts `array`, whose elements have type `elementType`, to a Java collection. Nested
     * arrays, maps and structs are converted element by element.
     */
    def convertArray(array: ArrayData, elementType: DataType): Any = {
        elementType match {
            case ArrayType(innerType, _) => array.toSeq[ArrayData](elementType).map(convertArray(_, innerType)).asJava
            case MapType(keyType, valueType, _) => array.toSeq[MapData](elementType).map(convertMap(_, keyType, valueType)).asJava
            case structType: StructType => array.toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)).asJava
            case StringType => convertStringArray(array).asJava
            case _ => array.toSeq[Any](elementType).asJava

    /**
     * Converts `map` to a `java.util.Map` keyed by string, converting nested values
     * recursively.
     *
     * @throws IllegalArgumentException if `keyType` is not `StringType`
     */
    def convertMap(map: MapData, keyType: DataType, valueType: DataType): util.Map[String, Any] = {
        if (keyType != StringType) throw new IllegalArgumentException(
            s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
        val keys = convertStringArray(map.keyArray())
        val values = valueType match {
            case ArrayType(innerType, _) => map.valueArray().toSeq[ArrayData](valueType).map(convertArray(_, innerType))
            case MapType(innerKeyType, innerValueType, _) => map.valueArray().toSeq[MapData](valueType).map(convertMap(_, innerKeyType, innerValueType))
            case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(convertStruct(_, structType.fields))
            case StringType => convertStringArray(map.valueArray())
            case _ => map.valueArray().toSeq[Any](valueType)
        // Keys and values are paired by position.
        val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i)
        Map(kvPairs: _*).asJava

    /**
     * Converts a struct row to a `java.util.Map` from field name to converted value. A field
     * that is null in `row` maps to null.
     */
    def convertStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = {
        val kvPairs = for (i <- 0 until row.numFields) yield
            if (row.isNullAt(i)) fields(i).name -> null
            else fields(i).name -> convertRowValue(row, i, fields(i).dataType)
        Map(kvPairs: _*).asJava

    // NOTE(review): the body of this method is missing from this excerpt; from its signature it
    // turns an ArrayData into a Seq[String].
    def convertStringArray(array: ArrayData): Seq[String] =

Example 6
Source File: GenerateUnsafeProjectionSuite.scala — from Spark-2.3.1, Apache License 2.0 (5 votes)
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

// Test suite for GenerateUnsafeProjection.
// NOTE(review): closing braces are not visible in this excerpt.
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
  test("Test unsafe projection string access pattern") {
    // Project a single struct column with one string field, feeding it the AlwaysNull row.
    val dataType = (new StructType).add("a", StringType)
    val exprs = BoundReference(0, dataType, nullable = true) :: Nil
    val projection = GenerateUnsafeProjection.generate(exprs)
    val result = projection.apply(InternalRow(AlwaysNull))
    // The projected struct must report its only field as null.
    assert(result.getStruct(0, 1).isNullAt(0))

// InternalRow stub with a single field that reports every ordinal as null. `update` and every
// typed getter throw UnsupportedOperationException, so any attempt to read a value fails.
// NOTE(review): the closing brace of this object is not visible in this excerpt.
object AlwaysNull extends InternalRow {
  override def numFields: Int = 1
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this
  override def anyNull: Boolean = true
  override def isNullAt(ordinal: Int): Boolean = true
  override def update(i: Int, value: Any): Unit = notSupported
  override def getBoolean(ordinal: Int): Boolean = notSupported
  override def getByte(ordinal: Int): Byte = notSupported
  override def getShort(ordinal: Int): Short = notSupported
  override def getInt(ordinal: Int): Int = notSupported
  override def getLong(ordinal: Int): Long = notSupported
  override def getFloat(ordinal: Int): Float = notSupported
  override def getDouble(ordinal: Int): Double = notSupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
  override def getUTF8String(ordinal: Int): UTF8String = notSupported
  override def getBinary(ordinal: Int): Array[Byte] = notSupported
  override def getInterval(ordinal: Int): CalendarInterval = notSupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
  override def getArray(ordinal: Int): ArrayData = notSupported
  override def getMap(ordinal: Int): MapData = notSupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
  // Typed as Nothing so it can stand in for any getter's return type.
  private def notSupported: Nothing = throw new UnsupportedOperationException
Example 7
Source File: JacksonGenerator.scala — from BigDatalog, Apache License 2.0 (5 votes)
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils}

import scala.collection.Map

import com.fasterxml.jackson.core._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Writes an InternalRow to a Jackson JsonGenerator, choosing the write call from each value's
// Catalyst data type.
// NOTE(review): several lines (closing braces and some generator calls) are not visible in
// this excerpt.
private[sql] object JacksonGenerator {
  /**
   * Writes `row`, interpreted with `rowSchema`, to `gen`.
   *
   * @param rowSchema schema describing the fields of `row`
   * @param gen the Jackson generator that receives the output
   * @param row the row to write
   */
  def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = {
    // Partial function over (data type, value) pairs; nested values recurse into valWriter.
    def valWriter: (DataType, Any) => Unit = {
      // A null value, or any value of NullType, is written as JSON null.
      case (_, null) | (NullType, _) => gen.writeNull()
      case (StringType, v) => gen.writeString(v.toString)
      // Timestamps and dates are written as strings via DateTimeUtils.toJavaTimestamp and
      // DateTimeUtils.toJavaDate.
      case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString)
      case (IntegerType, v: Int) => gen.writeNumber(v)
      case (ShortType, v: Short) => gen.writeNumber(v)
      case (FloatType, v: Float) => gen.writeNumber(v)
      case (DoubleType, v: Double) => gen.writeNumber(v)
      case (LongType, v: Long) => gen.writeNumber(v)
      case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal)
      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
      case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString)
      // For UDT values, they should be in the SQL type's corresponding value type.
      // We should not see values in the user-defined class at here.
      // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is
      // an ArrayData at here, instead of a Vector.
      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)

      // Each array element is written with the element type's writer.
      case (ArrayType(ty, _), v: ArrayData) =>
        v.foreach(ty, (_, value) => valWriter(ty, value))

      // Each map value is written with the value type's writer; how the key is written is
      // not visible in this excerpt.
      case (MapType(kt, vt, _), v: MapData) =>
        v.foreach(kt, vt, { (k, v) =>
          valWriter(vt, v)

      // Struct fields are visited in order; a field whose value is null is skipped.
      case (StructType(ty), v: InternalRow) =>
        var i = 0
        while (i < ty.length) {
          val field = ty(i)
          val value = v.get(i, field.dataType)
          if (value != null) {
            valWriter(field.dataType, value)
          i += 1

      // Fallback for unsupported (type, value) combinations.
      // NOTE(review): the call wrapping this message is not visible in this excerpt, and the
      // message text contains a stray closing brace right after the class-name interpolation.
      case (dt, v) =>
          s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.")

    valWriter(rowSchema, row)
Example 8
Source File: ExamplePointUDT.scala — from BigDatalog, Apache License 2.0 (5 votes)
package org.apache.spark.sql.test

import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._

// User-defined type that stores an ExamplePoint as a two-element array of doubles (x, y).
// NOTE(review): closing braces are not visible in this excerpt, so any further match cases in
// serialize/deserialize cannot be seen here.
private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

  /** Underlying SQL type: an array of doubles that does not contain nulls. */
  override def sqlType: DataType = ArrayType(DoubleType, false)

  /** Name of the paired Python UDT class. */
  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

  /** Packs an ExamplePoint into a GenericArrayData holding x at index 0 and y at index 1. */
  override def serialize(obj: Any): GenericArrayData = {
    obj match {
      case p: ExamplePoint =>
        val output = new Array[Any](2)
        output(0) = p.x
        output(1) = p.y
        new GenericArrayData(output)

  /** Rebuilds an ExamplePoint from the doubles at index 0 and index 1 of an ArrayData. */
  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
      case values: ArrayData =>
        new ExamplePoint(values.getDouble(0), values.getDouble(1))

  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]

  // This UDT is returned unchanged as its own nullable version.
  private[spark] override def asNullable: ExamplePointUDT = this
Example 9
Source File: MeanSubstitute.scala — from glow, Apache License 2.0 (5 votes)
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution

/**
 * Expression over a numeric array and a missing-value marker. Elements that count as missing
 * are replaced with `arrayMean` (see `substituteWithMean`).
 *
 * NOTE(review): several method bodies and closing braces are not visible in this excerpt, so
 * the comments below describe only the lines that are shown.
 *
 * @param array expression producing the array to process
 * @param missingValue expression whose value marks an element as missing
 */
case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  /** Secondary constructor that uses the literal -1 as the missing value. */
  def this(array: Expression) = {
    this(array, Literal(-1))

  // Cast is unchecked here; `rewrite` verifies that `array` has an ArrayType before use.
  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  /** Builds a named struct with fields "sum" and "count" holding the two given expressions. */
  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)

  // Update sum and count with array element if not missing
  // NOTE(review): the conditional wrapping the two expressions below is not visible here.
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
      // If value is missing, do not update sum and count
      // If value is not missing, add to sum and increment count
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)

  // Calculate mean for imputation
  // NOTE(review): the enclosing conditional and its final branch are not visible here.
  def calculateMean(stateStruct: Expression): Expression = {
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value

  // NOTE(review): only the initial state (sum 0.0, count 0) is visible for this definition.
  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
      createNamedStruct(Literal(0d), Literal(0L)),

  /** Yields `arrayMean` for a missing element and the element itself otherwise. */
  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)

  /**
   * Validates the input types before rewriting: `array` must be an array with a numeric
   * element type and `missingValue` must be numeric; otherwise an analysis exception is thrown.
   */
  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")

    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")

    // Replace missing values with the provided strategy
Example 10
Source File: LinearRegressionExpr.scala — from glow, Apache License 2.0 (5 votes)
package io.projectglow.sql.expressions

import breeze.linalg.DenseVector
import org.apache.spark.TaskContext
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, TernaryExpression}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._

// Companion holding the evaluation entry point shared by interpreted and generated code.
// NOTE(review): several lines of doLinearRegression (the state initialisation and the call
// receiving the two vectors) are not visible in this excerpt.
object LinearRegressionExpr {
  private val matrixUDT = SQLUtils.newMatrixUDT()
  // Per-thread cached context; cleared by the task-completion listener registered below.
  private val state = new ThreadLocal[CovariateQRContext]

  /**
   * Runs the regression for one row. `genotypes` and `phenotypes` are cast to ArrayData and
   * read as double arrays.
   */
  def doLinearRegression(genotypes: Any, phenotypes: Any, covariates: Any): InternalRow = {

    if (state.get() == null) {
      // Save the QR factorization of the covariate matrix since it's the same for every row
      TaskContext.get().addTaskCompletionListener[Unit](_ => state.remove())

      new DenseVector[Double](genotypes.asInstanceOf[ArrayData].toDoubleArray()),
      new DenseVector[Double](phenotypes.asInstanceOf[ArrayData].toDoubleArray()),

/**
 * Ternary expression that delegates evaluation to `LinearRegressionExpr.doLinearRegression`.
 *
 * NOTE(review): closing braces and parts of `dataType` and `doGenCode` are not visible in
 * this excerpt.
 *
 * @param genotypes expression cast to an array of doubles (see `inputTypes`)
 * @param phenotypes expression cast to an array of doubles (see `inputTypes`)
 * @param covariates expression cast to the matrix UDT (see `inputTypes`)
 */
case class LinearRegressionExpr(
    genotypes: Expression,
    phenotypes: Expression,
    covariates: Expression)
    extends TernaryExpression
    with ImplicitCastInputTypes {

  private val matrixUDT = SQLUtils.newMatrixUDT()

  // Result fields visible here: beta, standardError and pValue, all doubles.
  override def dataType: DataType =
        StructField("beta", DoubleType),
        StructField("standardError", DoubleType),
        StructField("pValue", DoubleType)))

  override def inputTypes: Seq[DataType] =
    Seq(ArrayType(DoubleType), ArrayType(DoubleType), matrixUDT)

  override def children: Seq[Expression] = Seq(genotypes, phenotypes, covariates)

  // Interpreted evaluation: forwards the three evaluated inputs to the companion object.
  override protected def nullSafeEval(genotypes: Any, phenotypes: Any, covariates: Any): Any = {
    LinearRegressionExpr.doLinearRegression(genotypes, phenotypes, covariates)

  // Generated code calls the same companion method by its fully qualified name.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
      (genotypes, phenotypes, covariates) => {
         |${ev.value} = io.projectglow.sql.expressions.LinearRegressionExpr.doLinearRegression($genotypes, $phenotypes, $covariates);
Example 11
Source File: ShapeType.scala — from Simba, Apache License 2.0 (5 votes)
package org.apache.spark.sql.simba

import org.apache.spark.sql.types._
import org.apache.spark.sql.simba.spatial.Shape
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}

// User-defined type that stores a Shape as an array of bytes.
// NOTE(review): closing braces and the body of the deserialize match case are not visible in
// this excerpt.
private[simba] class ShapeType extends UserDefinedType[Shape] {
  /** Underlying SQL type: an array of bytes that does not contain nulls. */
  override def sqlType: DataType = ArrayType(ByteType, containsNull = false)

  /** Wraps the output of ShapeSerializer.serialize in a GenericArrayData. */
  override def serialize(s: Shape): Any = {
    new GenericArrayData(ShapeSerializer.serialize(s))

  override def userClass: Class[Shape] = classOf[Shape]

  /** Rebuilds a Shape from an ArrayData value; the conversion itself is not shown here. */
  override def deserialize(datum: Any): Shape = {
    datum match {
      case values: ArrayData =>


// Singleton instance of the ShapeType class.
case object ShapeType extends ShapeType