org.apache.spark.sql.catalyst.util.ArrayData Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.util.ArrayData. Each example lists the source file it was taken from, the project it belongs to, and that project's license.
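Before the examples, a minimal sketch of the ArrayData API itself (a hypothetical ArrayDataBasics driver, assuming only a spark-catalyst dependency; the calls used are the same ones that appear in the examples below):

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}

object ArrayDataBasics {
  def main(args: Array[String]): Unit = {
    // Wrap a JVM array in Catalyst's internal array representation.
    val data: ArrayData = ArrayData.toArrayData(Array(1, 2, 3))

    // Ordinal-based, typed access (the SpecializedGetters contract).
    println(data.numElements())              // 3
    println(data.getInt(0))                  // 1
    println(data.toIntArray().mkString(",")) // 1,2,3

    // GenericArrayData is the simplest concrete implementation and allows nulls.
    val withNull = new GenericArrayData(Array[Any](1, null, 3))
    println(withNull.isNullAt(1))            // true
  }
}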
Example 1
Source File: ArrayType.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.types

import scala.math.Ordering

import org.json4s.JsonDSL._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.util.ArrayData


// Excerpt from the ArrayType case class; its declaration is restored for context
// (the companion object and other members are omitted).
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {

  override def defaultSize: Int = 100 * elementType.defaultSize

  override def simpleString: String = s"array<${elementType.simpleString}>"

  override def catalogString: String = s"array<${elementType.catalogString}>"

  override def sql: String = s"ARRAY<${elementType.sql}>"

  override private[spark] def asNullable: ArrayType =
    ArrayType(elementType.asNullable, containsNull = true)

  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
    f(this) || elementType.existsRecursively(f)
  }

  @transient
  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
    private[this] val elementOrdering: Ordering[Any] = elementType match {
      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case other =>
        throw new IllegalArgumentException(s"Type $other does not support ordered operations")
    }

    def compare(x: ArrayData, y: ArrayData): Int = {
      val leftArray = x
      val rightArray = y
      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
      var i = 0
      while (i < minLength) {
        val isNullLeft = leftArray.isNullAt(i)
        val isNullRight = rightArray.isNullAt(i)
        if (isNullLeft && isNullRight) {
          // Do nothing.
        } else if (isNullLeft) {
          return -1
        } else if (isNullRight) {
          return 1
        } else {
          val comp =
            elementOrdering.compare(
              leftArray.get(i, elementType),
              rightArray.get(i, elementType))
          if (comp != 0) {
            return comp
          }
        }
        i += 1
      }
      if (leftArray.numElements() < rightArray.numElements()) {
        return -1
      } else if (leftArray.numElements() > rightArray.numElements()) {
        return 1
      } else {
        return 0
      }
    }
  }
} 
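A quick sketch of how this ordering behaves (hypothetical ArrayOrderingSketch object; interpretedOrdering is private[sql], so a sketch like this only compiles inside the org.apache.spark.sql package tree, for example in spark-catalyst test sources):

package org.apache.spark.sql.types

import org.apache.spark.sql.catalyst.util.GenericArrayData

object ArrayOrderingSketch {
  def main(args: Array[String]): Unit = {
    val ordering = ArrayType(IntegerType, containsNull = true).interpretedOrdering

    val shorter  = new GenericArrayData(Array[Any](1, 2))
    val longer   = new GenericArrayData(Array[Any](1, 2, 3))
    val withNull = new GenericArrayData(Array[Any](1, null))

    println(ordering.compare(shorter, longer))   // -1: equal prefix, shorter array sorts first
    println(ordering.compare(withNull, shorter)) // -1: a null element sorts before any value
    println(ordering.compare(longer, longer))    //  0
  }
}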
Example 2
Source File: CardinalityHashFunctionTest.scala    From spark-alchemy   with Apache License 2.0
package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.{Matchers, WordSpec}


class CardinalityHashFunctionTest extends WordSpec with Matchers {

  "Cardinality hash functions" should {
    "account for nulls" in {
      val a = UTF8String.fromString("a")

      allDistinct(Seq(
        null,
        Array.empty[Byte],
        Array.apply(1.toByte)
      ), BinaryType)

      allDistinct(Seq(
        null,
        UTF8String.fromString(""),
        a
      ), StringType)

      allDistinct(Seq(
        null,
        ArrayData.toArrayData(Array.empty),
        ArrayData.toArrayData(Array(null)),
        ArrayData.toArrayData(Array(null, null)),
        ArrayData.toArrayData(Array(a, null)),
        ArrayData.toArrayData(Array(null, a))
      ), ArrayType(StringType))


      allDistinct(Seq(
        null,
        ArrayBasedMapData(Map.empty),
        ArrayBasedMapData(Map(null.asInstanceOf[String] -> null))
      ), MapType(StringType, StringType))

      allDistinct(Seq(
        null,
        InternalRow(null),
        InternalRow(a)
      ), new StructType().add("foo", StringType))

      allDistinct(Seq(
        InternalRow(null, a),
        InternalRow(a, null)
      ), new StructType().add("foo", StringType).add("bar", StringType))
    }
  }

  def allDistinct(values: Seq[Any], dataType: DataType): Unit = {
    val hashed = values.map(x => CardinalityXxHash64Function.hash(x, dataType, 0))
    hashed.distinct.length should be(hashed.length)
  }

} 
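The cases are chosen so that the hash distinguishes a null value from an empty array, and an array containing a null from one containing a value; if any of these collided, distinct structures would collapse into the same bucket of the HyperLogLog sketch this project builds on.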
Example 3
Source File: ArrayType.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.types

import scala.math.Ordering

import org.json4s.JsonDSL._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.util.ArrayData


// Excerpt from the ArrayType case class; its declaration is restored for context
// (the companion object and other members are omitted).
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {

  override def defaultSize: Int = 1 * elementType.defaultSize

  override def simpleString: String = s"array<${elementType.simpleString}>"

  override def catalogString: String = s"array<${elementType.catalogString}>"

  override def sql: String = s"ARRAY<${elementType.sql}>"

  override private[spark] def asNullable: ArrayType =
    ArrayType(elementType.asNullable, containsNull = true)

  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
    f(this) || elementType.existsRecursively(f)
  }

  @transient
  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
    private[this] val elementOrdering: Ordering[Any] = elementType match {
      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case other =>
        throw new IllegalArgumentException(
          s"Type ${other.catalogString} does not support ordered operations")
    }

    def compare(x: ArrayData, y: ArrayData): Int = {
      val leftArray = x
      val rightArray = y
      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
      var i = 0
      while (i < minLength) {
        val isNullLeft = leftArray.isNullAt(i)
        val isNullRight = rightArray.isNullAt(i)
        if (isNullLeft && isNullRight) {
          // Do nothing.
        } else if (isNullLeft) {
          return -1
        } else if (isNullRight) {
          return 1
        } else {
          val comp =
            elementOrdering.compare(
              leftArray.get(i, elementType),
              rightArray.get(i, elementType))
          if (comp != 0) {
            return comp
          }
        }
        i += 1
      }
      if (leftArray.numElements() < rightArray.numElements()) {
        return -1
      } else if (leftArray.numElements() > rightArray.numElements()) {
        return 1
      } else {
        return 0
      }
    }
  }
} 
Example 4
Source File: InternalRow.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


// Excerpt from the InternalRow companion object; its declaration is restored for context.
object InternalRow {

  def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
    case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
    case ByteType => (input, ordinal) => input.getByte(ordinal)
    case ShortType => (input, ordinal) => input.getShort(ordinal)
    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
    case FloatType => (input, ordinal) => input.getFloat(ordinal)
    case DoubleType => (input, ordinal) => input.getDouble(ordinal)
    case StringType => (input, ordinal) => input.getUTF8String(ordinal)
    case BinaryType => (input, ordinal) => input.getBinary(ordinal)
    case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
    case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
    case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
    case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
    case _: MapType => (input, ordinal) => input.getMap(ordinal)
    case u: UserDefinedType[_] => getAccessor(u.sqlType)
    case _ => (input, ordinal) => input.get(ordinal, dataType)
  }
} 
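A minimal usage sketch for the two-argument accessor shown above (hypothetical AccessorSketch object, assuming the Spark 2.4-era signature excerpted here; later versions add a nullable flag). The accessor is resolved once per data type and then reused for every row:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object AccessorSketch {
  def main(args: Array[String]): Unit = {
    // Resolve the accessor once per data type, then apply it to any row.
    val getString = InternalRow.getAccessor(StringType)
    val row = InternalRow(UTF8String.fromString("hello"))
    println(getString(row, 0)) // hello
  }
}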
Example 5
Source File: JavaConverter.scala    From spark-dynamodb   with Apache License 2.0
package com.audienceproject.spark.dynamodb.catalyst

import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.JavaConverters._

object JavaConverter {

    def convertRowValue(row: InternalRow, index: Int, elementType: DataType): Any = {
        elementType match {
            case ArrayType(innerType, _) => convertArray(row.getArray(index), innerType)
            case MapType(keyType, valueType, _) => convertMap(row.getMap(index), keyType, valueType)
            case StructType(fields) => convertStruct(row.getStruct(index, fields.length), fields)
            case StringType => row.getString(index)
            case _ => row.get(index, elementType)
        }
    }

    def convertArray(array: ArrayData, elementType: DataType): Any = {
        elementType match {
            case ArrayType(innerType, _) => array.toSeq[ArrayData](elementType).map(convertArray(_, innerType)).asJava
            case MapType(keyType, valueType, _) => array.toSeq[MapData](elementType).map(convertMap(_, keyType, valueType)).asJava
            case structType: StructType => array.toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)).asJava
            case StringType => convertStringArray(array).asJava
            case _ => array.toSeq[Any](elementType).asJava
        }
    }

    def convertMap(map: MapData, keyType: DataType, valueType: DataType): util.Map[String, Any] = {
        if (keyType != StringType) throw new IllegalArgumentException(
            s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
        val keys = convertStringArray(map.keyArray())
        val values = valueType match {
            case ArrayType(innerType, _) => map.valueArray().toSeq[ArrayData](valueType).map(convertArray(_, innerType))
            case MapType(innerKeyType, innerValueType, _) => map.valueArray().toSeq[MapData](valueType).map(convertMap(_, innerKeyType, innerValueType))
            case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(convertStruct(_, structType.fields))
            case StringType => convertStringArray(map.valueArray())
            case _ => map.valueArray().toSeq[Any](valueType)
        }
        val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i)
        Map(kvPairs: _*).asJava
    }

    def convertStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = {
        val kvPairs = for (i <- 0 until row.numFields) yield
            if (row.isNullAt(i)) fields(i).name -> null
            else fields(i).name -> convertRowValue(row, i, fields(i).dataType)
        Map(kvPairs: _*).asJava
    }

    def convertStringArray(array: ArrayData): Seq[String] =
        array.toSeq[UTF8String](StringType).map(_.toString)

} 
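A short usage sketch for the converter above (hypothetical JavaConverterSketch driver; only calls shown in the example plus standard Catalyst factories are assumed):

import com.audienceproject.spark.dynamodb.catalyst.JavaConverter
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object JavaConverterSketch {
  def main(args: Array[String]): Unit = {
    // Catalyst strings arrive as UTF8String; convertArray yields a plain java.util.List.
    val array = ArrayData.toArrayData(Array(UTF8String.fromString("a"), UTF8String.fromString("b")))
    println(JavaConverter.convertArray(array, StringType)) // [a, b]
  }
}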
Example 6
Source File: GenerateUnsafeProjectionSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class GenerateUnsafeProjectionSuite extends SparkFunSuite {
  test("Test unsafe projection string access pattern") {
    val dataType = (new StructType).add("a", StringType)
    val exprs = BoundReference(0, dataType, nullable = true) :: Nil
    val projection = GenerateUnsafeProjection.generate(exprs)
    val result = projection.apply(InternalRow(AlwaysNull))
    assert(!result.isNullAt(0))
    assert(result.getStruct(0, 1).isNullAt(0))
  }
}

object AlwaysNull extends InternalRow {
  override def numFields: Int = 1
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this
  override def anyNull: Boolean = true
  override def isNullAt(ordinal: Int): Boolean = true
  override def update(i: Int, value: Any): Unit = notSupported
  override def getBoolean(ordinal: Int): Boolean = notSupported
  override def getByte(ordinal: Int): Byte = notSupported
  override def getShort(ordinal: Int): Short = notSupported
  override def getInt(ordinal: Int): Int = notSupported
  override def getLong(ordinal: Int): Long = notSupported
  override def getFloat(ordinal: Int): Float = notSupported
  override def getDouble(ordinal: Int): Double = notSupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
  override def getUTF8String(ordinal: Int): UTF8String = notSupported
  override def getBinary(ordinal: Int): Array[Byte] = notSupported
  override def getInterval(ordinal: Int): CalendarInterval = notSupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
  override def getArray(ordinal: Int): ArrayData = notSupported
  override def getMap(ordinal: Int): MapData = notSupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
  private def notSupported: Nothing = throw new UnsupportedOperationException
} 
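AlwaysNull throws from every typed getter (including getArray and getMap) but answers isNullAt with true, so the assertion that the nested struct's field is null only passes if the generated projection checks the null flag before invoking a typed accessor.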
Example 7
Source File: JacksonGenerator.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils}

import scala.collection.Map

import com.fasterxml.jackson.core._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

private[sql] object JacksonGenerator {
  
  def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = {
    def valWriter: (DataType, Any) => Unit = {
      case (_, null) | (NullType, _) => gen.writeNull()
      case (StringType, v) => gen.writeString(v.toString)
      case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString)
      case (IntegerType, v: Int) => gen.writeNumber(v)
      case (ShortType, v: Short) => gen.writeNumber(v)
      case (FloatType, v: Float) => gen.writeNumber(v)
      case (DoubleType, v: Double) => gen.writeNumber(v)
      case (LongType, v: Long) => gen.writeNumber(v)
      case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal)
      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
      case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString)
      // For UDT values, they should be in the SQL type's corresponding value type.
      // We should not see values in the user-defined class at here.
      // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is
      // an ArrayData at here, instead of a Vector.
      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)

      case (ArrayType(ty, _), v: ArrayData) =>
        gen.writeStartArray()
        v.foreach(ty, (_, value) => valWriter(ty, value))
        gen.writeEndArray()

      case (MapType(kt, vt, _), v: MapData) =>
        gen.writeStartObject()
        v.foreach(kt, vt, { (k, v) =>
          gen.writeFieldName(k.toString)
          valWriter(vt, v)
        })
        gen.writeEndObject()

      case (StructType(ty), v: InternalRow) =>
        gen.writeStartObject()
        var i = 0
        while (i < ty.length) {
          val field = ty(i)
          val value = v.get(i, field.dataType)
          if (value != null) {
            gen.writeFieldName(field.name)
            valWriter(field.dataType, value)
          }
          i += 1
        }
        gen.writeEndObject()

      case (dt, v) =>
        sys.error(
          s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.")
    }

    valWriter(rowSchema, row)
  }
} 
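A sketch of driving this generator for a single row (hypothetical JacksonGeneratorSketch object; JacksonGenerator is private[sql], so this assumes sql-package or test-scope access, while the Jackson JsonFactory usage is standard):

import java.io.StringWriter

import com.fasterxml.jackson.core.JsonFactory

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types._

object JacksonGeneratorSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("scores", ArrayType(IntegerType))))
    val row = InternalRow(new GenericArrayData(Array[Any](1, 2, 3)))

    val writer = new StringWriter()
    val gen = new JsonFactory().createGenerator(writer)
    JacksonGenerator(schema, gen)(row) // writes {"scores":[1,2,3]}
    gen.flush()
    println(writer.toString)
  }
}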
Example 8
Source File: ExamplePointUDT.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.test

import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._


// The companion value class, trimmed to its constructor for this excerpt:
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable

private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

  override def sqlType: DataType = ArrayType(DoubleType, false)

  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

  override def serialize(obj: Any): GenericArrayData = {
    obj match {
      case p: ExamplePoint =>
        val output = new Array[Any](2)
        output(0) = p.x
        output(1) = p.y
        new GenericArrayData(output)
    }
  }

  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
      case values: ArrayData =>
        new ExamplePoint(values.getDouble(0), values.getDouble(1))
    }
  }

  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]

  private[spark] override def asNullable: ExamplePointUDT = this
} 
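A round-trip sketch for the UDT above (hypothetical ExamplePointUDTSketch object; ExamplePoint and ExamplePointUDT are private[sql], so a sketch like this would live in the same package, e.g. test sources):

package org.apache.spark.sql.test

object ExamplePointUDTSketch {
  def main(args: Array[String]): Unit = {
    val udt = new ExamplePointUDT

    // serialize produces the ArrayType(DoubleType) representation...
    val arrayData = udt.serialize(new ExamplePoint(1.0, 2.0))
    println(s"${arrayData.getDouble(0)}, ${arrayData.getDouble(1)}") // 1.0, 2.0

    // ...and deserialize reconstructs the user-facing class.
    val point = udt.deserialize(arrayData)
    println(s"${point.x}, ${point.y}") // 1.0, 2.0
  }
}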
Example 9
Source File: MeanSubstitute.scala    From glow   with Apache License 2.0
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution


case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  def this(array: Expression) = {
    this(array, Literal(-1))
  }

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)
  }

  // Update sum and count with array element if not missing
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    If(
      isMissing(arrayElement),
      // If value is missing, do not update sum and count
      stateStruct,
      // If value is not missing, add to sum and increment count
      createNamedStruct(
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)
    )
  }

  // Calculate mean for imputation
  def calculateMean(stateStruct: Expression): Expression = {
    If(
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value
      missingValue
    )
  }

  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
    array.aggregate(
      createNamedStruct(Literal(0d), Literal(0L)),
      updateSumAndCountConditionally,
      calculateMean
    )
  }

  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)
  }

  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }

    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    // Replace missing values with the provided strategy
    array.arrayTransform(substituteWithMean(_))
  }
} 
Example 10
Source File: LinearRegressionExpr.scala    From glow   with Apache License 2.0
package io.projectglow.sql.expressions

import breeze.linalg.DenseVector
import org.apache.spark.TaskContext
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, TernaryExpression}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._

object LinearRegressionExpr {
  private val matrixUDT = SQLUtils.newMatrixUDT()
  private val state = new ThreadLocal[CovariateQRContext]

  def doLinearRegression(genotypes: Any, phenotypes: Any, covariates: Any): InternalRow = {

    if (state.get() == null) {
      // Save the QR factorization of the covariate matrix since it's the same for every row
      state.set(CovariateQRContext.computeQR(matrixUDT.deserialize(covariates).toDense))
      TaskContext.get().addTaskCompletionListener[Unit](_ => state.remove())
    }

    LinearRegressionGwas.linearRegressionGwas(
      new DenseVector[Double](genotypes.asInstanceOf[ArrayData].toDoubleArray()),
      new DenseVector[Double](phenotypes.asInstanceOf[ArrayData].toDoubleArray()),
      state.get()
    )
  }
}

case class LinearRegressionExpr(
    genotypes: Expression,
    phenotypes: Expression,
    covariates: Expression)
    extends TernaryExpression
    with ImplicitCastInputTypes {

  private val matrixUDT = SQLUtils.newMatrixUDT()

  override def dataType: DataType =
    StructType(
      Seq(
        StructField("beta", DoubleType),
        StructField("standardError", DoubleType),
        StructField("pValue", DoubleType)))

  override def inputTypes: Seq[DataType] =
    Seq(ArrayType(DoubleType), ArrayType(DoubleType), matrixUDT)

  override def children: Seq[Expression] = Seq(genotypes, phenotypes, covariates)

  override protected def nullSafeEval(genotypes: Any, phenotypes: Any, covariates: Any): Any = {
    LinearRegressionExpr.doLinearRegression(genotypes, phenotypes, covariates)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(
      ctx,
      ev,
      (genotypes, phenotypes, covariates) => {
        s"""
         |${ev.value} = io.projectglow.sql.expressions.LinearRegressionExpr.doLinearRegression($genotypes, $phenotypes, $covariates);
       """.stripMargin
      }
    )
  }
} 
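ArrayData appears only at the boundary here: nullSafeEval hands the genotype and phenotype columns to doLinearRegression as ArrayData, and toDoubleArray() copies each into a dense Breeze vector before the regression runs.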
Example 11
Source File: ShapeType.scala    From Simba   with Apache License 2.0
package org.apache.spark.sql.simba

import org.apache.spark.sql.types._
import org.apache.spark.sql.simba.spatial.Shape
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}


private[simba] class ShapeType extends UserDefinedType[Shape] {
  override def sqlType: DataType = ArrayType(ByteType, containsNull = false)

  override def serialize(s: Shape): Any = {
    new GenericArrayData(ShapeSerializer.serialize(s))
  }

  override def userClass: Class[Shape] = classOf[Shape]

  override def deserialize(datum: Any): Shape = {
    datum match {
      case values: ArrayData =>
        ShapeSerializer.deserialize(values.toByteArray)
    }
  }
}

case object ShapeType extends ShapeType
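The UDT's SQL type is ArrayType(ByteType, containsNull = false): serialize wraps the serialized shape bytes in a GenericArrayData, and deserialize reads them back with ArrayData.toByteArray before handing off to ShapeSerializer.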