Example 1
Source File: AvroDataToCatalyst.scala    From spark-schema-registry   with Apache License 2.0
package com.hortonworks.spark.registry.avro


import com.hortonworks.registries.schemaregistry.{SchemaVersionInfo, SchemaVersionKey}
import com.hortonworks.registries.schemaregistry.client.SchemaRegistryClient
import com.hortonworks.registries.schemaregistry.serdes.avro.AvroSnapshotDeserializer
import org.apache.avro.Schema
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

import scala.collection.JavaConverters._

/**
 * Unary Catalyst expression, rendered as `from_sr(...)`, that takes binary input and produces
 * a value whose type is derived from an Avro schema fetched for `schemaName` / `version`.
 *
 * NOTE(review): several member bodies below stop without their closing braces (`srDeser`,
 * `nullSafeEval`, `simpleString`, `sql`, `doGenCode`, `fetchSchemaVersionInfo`); this listing
 * looks truncated and should be compared with the upstream file before it is compiled.
 *
 * @param child      expression supplying the input; `inputTypes` declares it as `BinaryType`
 * @param schemaName schema name passed to `fetchSchemaVersionInfo`
 * @param version    optional schema version passed to `fetchSchemaVersionInfo`
 * @param config     settings handed to `SchemaRegistryClient` via `config.asJava`
 */
case class AvroDataToCatalyst(child: Expression, schemaName: String, version: Option[Int], config: Map[String, Object])
  extends UnaryExpression with ExpectsInputTypes {

  // The single child must evaluate to binary data.
  override def inputTypes = Seq(BinaryType)

  // Schema-registry snapshot deserializer, created on first use.
  // NOTE(review): only the constructor call is visible; any further setup and the closing
  // brace are missing from this listing.
  @transient private lazy val srDeser: AvroSnapshotDeserializer = {
    val obj = new AvroSnapshotDeserializer()

  // Schema version info for (schemaName, version), fetched on first use.
  @transient private lazy val srSchema = fetchSchemaVersionInfo(schemaName, version)

  // Avro schema parsed from the schema text carried by srSchema.
  @transient private lazy val avroSchema = new Schema.Parser().parse(srSchema.getSchemaText)

  // Result type of the expression: the SQL type that SchemaConverters derives from avroSchema.
  override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType

  // AvroDeserializer constructed from the Avro schema and the derived Catalyst type.
  @transient private lazy val avroDeser= new AvroDeserializer(avroSchema, dataType)

  override def nullable: Boolean = true

  /**
   * Treats `input` as a byte array, deserializes it with `srDeser` using the version from
   * `srSchema`, and passes the outcome to `avroDeser`.
   *
   * NOTE(review): `ByteArrayInputStream` is not among the imports visible in this listing,
   * and the method's tail (use of `result`, closing braces) is missing.
   */
  override def nullSafeEval(input: Any): Any = {
    val binary = input.asInstanceOf[Array[Byte]]
    val row = avroDeser.deserialize(srDeser.deserialize(new ByteArrayInputStream(binary), srSchema.getVersion))
    val result = row match {
      // An InternalRow is copied before use; presumably the deserializer may reuse its row
      // instance between calls — TODO confirm.
      case r: InternalRow => r.copy()
      case _ => row

  // Short display form: child SQL plus the simple string of the result type.
  override def simpleString: String = {
    s"from_sr(${child.sql}, ${dataType.simpleString})"

  // SQL display form: child SQL plus the catalog string of the result type.
  override def sql: String = {
    s"from_sr(${child.sql}, ${dataType.catalogString})"

  // Registers this expression as a reference object for the generated code.
  // NOTE(review): the lambda passed to defineCodeGen has no visible body.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>

  // Looks up the schema version info through a SchemaRegistryClient built from `config`.
  // NOTE(review): the next line appears mangled (a stray `=>` and an unbound `v`); the
  // handling of `version` is not recoverable from this listing.
  private def fetchSchemaVersionInfo(schemaName: String, version: Option[Int]): SchemaVersionInfo = {
    val srClient = new SchemaRegistryClient(config.asJava) => srClient.getSchemaVersionInfo(new SchemaVersionKey(schemaName, v)))

Example 2
Source File: TimestampCast.scala    From flint   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.codegen.{ CodegenContext, ExprCode, CodeGenerator, JavaCode, Block }
import org.apache.spark.sql.catalyst.expressions.{ Expression, NullIntolerant, UnaryExpression }
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{ DataType, LongType, TimestampType }

/**
 * Expression with result type `LongType`. Both the interpreted path (`nullSafeEval`) and the
 * generated-code fragment (`cast`) multiply the child's value by 1000.
 *
 * NOTE(review): presumably this converts a microsecond timestamp to nanoseconds — confirm
 * against the representation of the child's type. `TimestampCast` is not defined in this
 * listing, and the closing brace of this class is missing.
 *
 * @param child expression whose value is scaled
 */
case class TimestampToNanos(child: Expression) extends TimestampCast {
  val dataType: DataType = LongType
  // Source-text fragment that scales the child's generated value.
  protected def cast(childPrim: String): String =
    s"$childPrim * 1000L"
  // Interpreted evaluation; the input is cast to Long before scaling.
  override protected def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Long] * 1000L

/**
 * Expression with result type `TimestampType`. Both the interpreted path (`nullSafeEval`) and
 * the generated-code fragment (`cast`) divide the child's value by 1000 using Long division,
 * so any remainder is discarded.
 *
 * NOTE(review): presumably this converts nanoseconds to a microsecond timestamp — confirm.
 * `TimestampCast` is not defined in this listing, and the closing brace of this class is
 * missing.
 *
 * @param child expression whose value is scaled
 */
case class NanosToTimestamp(child: Expression) extends TimestampCast {
  val dataType: DataType = TimestampType
  // Source-text fragment that scales the child's generated value.
  protected def cast(childPrim: String): String =
    s"$childPrim / 1000L"
  // Interpreted evaluation; the input is cast to Long before dividing.
  override protected def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Long] / 1000L

// Code-generation helpers for the timestamp cast expressions.
// NOTE(review): the members below use `cast`, `child` and `dataType`, none of which this
// object defines, and `doGenCode` is marked `override`. This may be the body of the
// `TimestampCast` type that the case classes extend, with its header altered — confirm
// against the upstream file.
object TimestampToNanos {
  // Builds the code fragment for one cast: the result null flag is initialised from the
  // child's null flag, the result variable gets the default value for `resultType`, and when
  // the child is not null the result is assigned the `cast(childPrim)` expression as a long.
  // NOTE(review): the opening of the code template and the closing lines are not present in
  // this listing; the lines after the signature are template text, not Scala statements.
  private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
    resultPrim: String, resultNull: String, resultType: DataType): Block = {
      boolean $resultNull = $childNull;
      ${CodeGenerator.javaType(resultType)} $resultPrim = ${CodeGenerator.defaultValue(resultType)};
      if (!${childNull}) {
        $resultPrim = (long) ${cast(childPrim)};

  // Generates the child's code, then appends the cast fragment that writes into `ev`.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType))
Example 3
Source File: XmlDataToCatalyst.scala    From spark-xml   with Apache License 2.0
package com.databricks.spark.xml

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import com.databricks.spark.xml.parsers.StaxXmlParser

/**
 * Unary Catalyst expression that parses XML text held in a column into a value of type
 * `schema`, delegating to `StaxXmlParser.parseColumn`. Evaluation is interpreted only
 * (`CodegenFallback`).
 *
 * NOTE(review): the `match` blocks below are missing their closing braces and two of the
 * `nullSafeEval` cases are mangled; compare with the upstream file before compiling.
 *
 * @param child   expression supplying the XML input
 * @param schema  result type; the visible cases handle a struct or an array of structs
 * @param options parser options forwarded unchanged to `StaxXmlParser.parseColumn`
 */
case class XmlDataToCatalyst(
    child: Expression,
    schema: DataType,
    options: XmlOptions)
  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {

  override lazy val dataType: DataType = schema

  // Struct schema handed to the parser: `schema` itself when it is a struct, or the element
  // type when `schema` is an array of structs. No other case is visible in this listing.
  lazy val rowSchema: StructType = schema match {
    case st: StructType => st
    case ArrayType(st: StructType, _) => st

  // Dispatches on the runtime type of the input; anything not matched evaluates to null.
  override def nullSafeEval(xml: Any): Any = xml match {
    case string: UTF8String =>
      // NOTE(review): the next line has an unmatched closing parenthesis and unusual
      // indentation — an enclosing call seems to have been lost.
        StaxXmlParser.parseColumn(string.toString, rowSchema, options))
    case string: String =>
      StaxXmlParser.parseColumn(string, rowSchema, options)
    case arr: GenericArrayData =>
      // NOTE(review): the argument below is mangled (lambda head missing); presumably each
      // element is parsed with parseColumn — confirm.
      CatalystTypeConverters.convertToCatalyst( => StaxXmlParser.parseColumn(s.toString, rowSchema, options)))
    // NOTE(review): mangled as well (doubled `=>`, lambda head missing).
    case arr: Array[_] => => StaxXmlParser.parseColumn(s.toString, rowSchema, options))
    case _ => null

  // Expected child type mirrors the shape of `schema`: a string for a struct result, an array
  // of strings for an array-of-structs result.
  override def inputTypes: Seq[DataType] = schema match {
    case _: StructType => Seq(StringType)
    case ArrayType(_: StructType, _) => Seq(ArrayType(StringType))
Example 4
Source File: FunctionBuilders.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Expression, BinaryExpression, UnaryExpression}

import scala.reflect.ClassTag

/**
 * Helpers that build `ExpressionBuilder` functions for expression classes by reflection.
 *
 * NOTE(review): `unaryExpression`, `binaryExpression` and `reverse` have no (complete) body in
 * this listing, and `expression` is missing closing parentheses/braces.
 */
object FunctionBuilders {

  // A builder takes the argument expressions and returns the constructed expression.
  type ExpressionBuilder = Seq[Expression] => Expression

  /**
   * Returns a builder for `T` that requires exactly `arity` arguments.
   *
   * The constructor of `T` taking `arity` parameters of type `Expression` is looked up once,
   * when the builder is created; the returned function checks the argument count and then
   * invokes that constructor.
   *
   * @param arity number of `Expression` parameters of the target constructor
   * @param tag   class tag used to obtain the runtime class of `T`
   * @throws IllegalArgumentException from the returned builder when the number of arguments
   *                                  differs from `arity`
   */
  def expression[T <: Expression](arity: Int)(implicit tag: ClassTag[T]): ExpressionBuilder = {
    // One classOf[Expression] entry per constructor parameter.
    val argTypes = (1 to arity).map(x => classOf[Expression])
    val constructor = tag.runtimeClass.getDeclaredConstructor(argTypes: _*)
    (expressions: Seq[Expression]) => {
      if (expressions.size != arity) {
        throw new IllegalArgumentException(
          s"Invalid number of arguments: ${expressions.size} (must be equal to $arity)"
      constructor.newInstance(expressions: _*).asInstanceOf[Expression]

  // NOTE(review): body missing; presumably delegates to expression[T] with arity 1 — confirm.
  def unaryExpression[T <: UnaryExpression](implicit tag: ClassTag[T]): ExpressionBuilder =

  // NOTE(review): body missing; presumably delegates to expression[T] with arity 2 — confirm.
  def binaryExpression[T <: BinaryExpression](implicit tag: ClassTag[T]): ExpressionBuilder =

  // Wraps another builder.
  // NOTE(review): the lambda body is missing; judging by the name it presumably reverses the
  // argument order before delegating — confirm.
  def reverse(expressionBuilder: ExpressionBuilder): ExpressionBuilder =
    (expressions: Seq[Expression]) => {

Example 5
Source File: CheckDeltaInvariant.scala    From delta   with Apache License 2.0

import{ArbitraryExpression, NotNull}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, NullType}

/**
 * Unary expression that checks `invariant` against the value of `child` and throws an
 * `InvariantViolationException` when the rule is violated. Its own result type is `NullType`.
 *
 * Two rule kinds are handled in the visible code: `NotNull` and `ArbitraryExpression`.
 *
 * NOTE(review): this listing is missing closing braces, the body of `eval`, and the opening
 * lines of the code templates in the two `generate*` helpers (only `|`-prefixed template text
 * remains). Compare with the upstream file before compiling.
 *
 * @param child     expression whose value is validated
 * @param invariant rule to enforce; also passed to the exception on violation
 */
case class CheckDeltaInvariant(
    child: Expression,
    invariant: Invariant) extends UnaryExpression with NonSQLExpression {

  override def dataType: DataType = NullType
  override def foldable: Boolean = false
  override def nullable: Boolean = true

  // Only the child is reported as a flat argument; `invariant` is left out.
  override def flatArguments: Iterator[Any] = Iterator(child)

  // Interpreted check of the rule against one input row.
  private def assertRule(input: InternalRow): Unit = invariant.rule match {
    case NotNull if child.eval(input) == null =>
      throw InvariantViolationException(invariant, "")
    case ArbitraryExpression(expr) =>
      // Every unresolved attribute in the rule expression is replaced by `child`.
      val resolvedExpr = expr.transform {
        case _: UnresolvedAttribute => child
      val result = resolvedExpr.eval(input)
      // A null result counts as a violation, the same as false.
      if (result == null || result == false) {
        throw InvariantViolationException(
          invariant, s"Value ${child.eval(input)} violates requirement.")

  // NOTE(review): body missing from this listing.
  override def eval(input: InternalRow): Any = {

  // Generated-code counterpart of the NotNull check: the emitted code throws when the child's
  // null flag is set. The invariant is made available to the generated code as a reference
  // object. NOTE(review): template opener and the thrown expression are incomplete below.
  private def generateNotNullCode(ctx: CodegenContext): Block = {
    val childGen = child.genCode(ctx)
    val invariantField = ctx.addReferenceObj("errMsg", invariant)
       |if (${childGen.isNull}) {
       |  throw
       |    $invariantField, "");

  // Generated-code counterpart of the ArbitraryExpression check: the emitted code throws when
  // the resolved rule evaluates to null or false, and the message includes the child's value
  // (the text "null" when the child is null).
  // NOTE(review): template opener and the thrown expression are incomplete below.
  private def generateExpressionValidationCode(expr: Expression, ctx: CodegenContext): Block = {
    val resolvedExpr = expr.transform {
      case _: UnresolvedAttribute => child
    val elementValue = child.genCode(ctx)
    val childGen = resolvedExpr.genCode(ctx)
    val invariantField = ctx.addReferenceObj("errMsg", invariant)
    val eValue = ctx.freshName("elementResult")
       |if (${childGen.isNull} || ${childGen.value} == false) {
       |  Object $eValue = "null";
       |  if (!${elementValue.isNull}) {
       |    $eValue = (Object) ${elementValue.value};
       |  }
       |  throw
       |     $invariantField, "Value " + $eValue + " violates requirement.");

  // Picks the code fragment for the rule kind; the expression itself always yields a null
  // literal of NullType with isNull set to true.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val code = invariant.rule match {
      case NotNull => generateNotNullCode(ctx)
      case ArbitraryExpression(expr) => generateExpressionValidationCode(expr, ctx)
    ev.copy(code = code, isNull = TrueLiteral, value = JavaCode.literal("null", NullType))
Example 6
Source File: EncodeLong.scala    From morpheus   with Apache License 2.0
package org.opencypher.morpheus.impl.expressions

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant, UnaryExpression}
import org.apache.spark.sql.types.{BinaryType, DataType, LongType}
import org.opencypher.morpheus.api.value.MorpheusElement._

/**
 * Unary Catalyst expression that takes a `LongType` child and produces `BinaryType`; the
 * generated code calls `encodeLong` on the class named after the `EncodeLong` companion.
 *
 * NOTE(review): `nullSafeEval` has no body in this listing and the class's closing brace is
 * missing.
 *
 * @param child expression supplying the Long to encode
 */
case class EncodeLong(child: Expression) extends UnaryExpression with NullIntolerant with ExpectsInputTypes {

  override val dataType: DataType = BinaryType

  override val inputTypes: Seq[LongType] = Seq(LongType)

  // NOTE(review): body missing; presumably calls EncodeLong.encodeLong on the input — confirm.
  override protected def nullSafeEval(input: Any): Any =

  // The class name is taken from the companion object's runtime class with its last character
  // dropped — presumably the trailing `$` of the object's class name, so that the generated
  // Java code calls the static forwarder on `EncodeLong`; confirm.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"(byte[])(${EncodeLong.getClass.getName.dropRight(1)}.encodeLong($c))")

object EncodeLong {

  // Bit 7 of an encoded byte (binary 1000 0000); encodeLong sets it on every byte it writes
  // inside its loop, and decodeLong keeps reading while it is set.
  private final val moreBytesBitMask: Long = 0x80L
  // Low seven bits of an encoded byte (binary 0111 1111): the payload bits.
  private final val varLength7BitMask: Long = 0x7FL
  // Everything above the low seven bits; non-zero under this mask means the value does not
  // fit into a single payload group.
  private final val otherBitsMask = ~varLength7BitMask
  // Size of the scratch buffer encodeLong allocates for one value.
  private final val maxBytesForLongVarEncoding = 10

  /**
   * Encodes `l` as a variable-length byte sequence: seven payload bits per byte, lowest group
   * first, with the high bit set on every byte except the last.
   *
   * Same encoding as Base 128 Varints (the reference link is missing from this listing).
   *
   * NOTE(review): the closing brace of the loop, the returned expression and the method's
   * closing brace are missing from this listing; presumably `result` is returned — confirm.
   *
   * @param l value to encode; shifted with `>>>`, i.e. treated as an unsigned bit pattern
   */
  final def encodeLong(l: Long): Array[Byte] = {
    // Scratch buffer of the maximum encoded size; trimmed to the used length further down.
    val tempResult = new Array[Byte](maxBytesForLongVarEncoding)

    var remainder = l
    var index = 0

    // While bits remain above the low seven, emit the low seven with the continuation bit set.
    while ((remainder & otherBitsMask) != 0) {
      tempResult(index) = ((remainder & varLength7BitMask) | moreBytesBitMask).toByte
      remainder >>>= 7
      index += 1
    // Final group: at most seven bits are left, so the continuation bit stays clear.
    tempResult(index) = remainder.toByte

    val result = new Array[Byte](index + 1)
    System.arraycopy(tempResult, 0, result, 0, index + 1)

  /**
   * Decodes a byte sequence in which each byte carries seven payload bits (lowest group
   * first) and the high bit marks that another byte follows.
   *
   * Same encoding as Base 128 Varints (the reference link is missing from this listing).
   *
   * Asserts that `input` is non-empty and that decoding consumed the whole array.
   *
   * NOTE(review): the closing brace of the loop, the returned expression and the method's
   * closing brace are missing from this listing; presumably `decoded` is returned — confirm.
   *
   * @param input bytes to decode
   */
  final def decodeLong(input: Array[Byte]): Long = {
    assert(input.nonEmpty, "`decodeLong` requires a non-empty array as its input")
    var index = 0
    var currentByte = input(index)
    // Start with the payload of the first byte; the mask is a Long, so `decoded` is a Long.
    var decoded = currentByte & varLength7BitMask
    var nextLeftShift = 7

    // Keep reading while the continuation bit is set, placing each payload group seven bits
    // higher than the previous one.
    while ((currentByte & moreBytesBitMask) != 0) {
      index += 1
      currentByte = input(index)
      decoded |= (currentByte & varLength7BitMask) << nextLeftShift
      nextLeftShift += 7
    // The last byte read must be the last byte of the array.
    assert(index == input.length - 1,
      s"`decodeLong` received an input array ${input.toSeq.toHex} with extra bytes that could not be decoded.")

  /**
   * Extension methods on `Column` for applying the `EncodeLong` expression.
   *
   * NOTE(review): the named overload has no body and the class's closing brace is missing in
   * this listing.
   *
   * @param c column whose expression is wrapped
   */
  implicit class ColumnLongOps(val c: Column) extends AnyVal {

    // NOTE(review): body missing; presumably the encoded column aliased to `name` — confirm.
    def encodeLongAsMorpheusId(name: String): Column =

    // Wraps the column's expression in EncodeLong and returns it as a new Column.
    def encodeLongAsMorpheusId: Column = new Column(EncodeLong(c.expr))


Example 7
Source File: CatalystDataToAvro.scala    From spark-schema-registry   with Apache License 2.0
package com.hortonworks.spark.registry.avro

import com.hortonworks.registries.schemaregistry.{SchemaCompatibility, SchemaMetadata}
import com.hortonworks.registries.schemaregistry.avro.AvroSchemaProvider
import com.hortonworks.registries.schemaregistry.client.SchemaRegistryClient
import com.hortonworks.registries.schemaregistry.serdes.avro.AvroSnapshotSerializer
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

import scala.collection.JavaConverters._

/**
 * Unary Catalyst expression, rendered as `to_sr(...)`, that serializes the child's value with
 * an `AvroSerializer` and then with a schema-registry `AvroSnapshotSerializer`, producing
 * `BinaryType`.
 *
 * NOTE(review): `srSer`, `schemaMetadata`, `nullSafeEval`, `simpleString`, `sql` and
 * `doGenCode` are missing closing braces (and, for the first two, part of their logic) in
 * this listing; compare with the upstream file before compiling.
 *
 * @param child      expression supplying the value to serialize
 * @param schemaName schema name used for the metadata lookup; also the record name fallback
 * @param recordName top-level Avro record name; the empty string selects `schemaName`
 * @param nameSpace  namespace passed to `SchemaConverters.toAvroType`
 * @param config     settings handed to `SchemaRegistryClient` via `config.asJava`
 */
case class CatalystDataToAvro(
    child: Expression,
    schemaName: String,
    recordName: String,
    nameSpace: String,
    config: Map[String, Object]
    ) extends UnaryExpression {

  override def dataType: DataType = BinaryType

  // An empty recordName falls back to the schema name.
  private val topLevelRecordName = if (recordName == "") schemaName else recordName

  // Avro type derived from the child's Catalyst type and nullability.
  @transient private lazy val avroType =
    SchemaConverters.toAvroType(child.dataType, child.nullable, topLevelRecordName, nameSpace)

  // AvroSerializer constructed from the child's type, the derived Avro type and nullability.
  @transient private lazy val avroSer =
    new AvroSerializer(child.dataType, avroType, child.nullable)

  // Schema-registry snapshot serializer, created on first use.
  // NOTE(review): only the constructor call is visible; any further setup and the closing
  // brace are missing from this listing.
  @transient private lazy val srSer: AvroSnapshotSerializer = {
    val obj = new AvroSnapshotSerializer()

  @transient private lazy val srClient = new SchemaRegistryClient(config.asJava)

  // Metadata for `schemaName`: looked up through the client, and when the lookup returns null
  // a new SchemaMetadata is being built with an autogenerated group and description.
  // NOTE(review): the builder chain is incomplete (line ends with a stray `.`), and the rest
  // of both branches is missing from this listing.
  @transient private lazy val schemaMetadata = {
    var schemaMetadataInfo = srClient.getSchemaMetadataInfo(schemaName)
    if (schemaMetadataInfo == null) {
      val generatedSchemaMetadata = new SchemaMetadata.Builder(schemaName).
        .schemaGroup("Autogenerated group")
        .description("Autogenerated schema")
    } else {

  // Serializes the input with avroSer, then hands the result and the schema metadata to srSer.
  override def nullSafeEval(input: Any): Any = {
    val avroData = avroSer.serialize(input)
    srSer.serialize(avroData.asInstanceOf[Object], schemaMetadata)

  // Short display form: child SQL plus the simple string of the child's type.
  override def simpleString: String = {
    s"to_sr(${child.sql}, ${child.dataType.simpleString})"

  // SQL display form: child SQL plus the catalog string of the child's type.
  override def sql: String = {
    s"to_sr(${child.sql}, ${child.dataType.catalogString})"

  // The generated code calls back into this instance's nullSafeEval through a reference
  // object and casts the result to byte[].
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>
      s"(byte[]) $expr.nullSafeEval($input)")
Example 8
Source File: SqlExtensionProviderSuite.scala    From glow   with Apache License 2.0
package io.projectglow.sql

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Literal, UnaryExpression}
import org.apache.spark.sql.types.{DataType, IntegerType}

import io.projectglow.GlowSuite

/**
 * Tests that SQL functions named `one_arg_test`, `two_arg_test`, `var_args_test` and
 * `optional_arg_test` can be called through `selectExpr` with the supported argument counts
 * and evaluate to 1, and that unsupported calls raise `AnalysisException`.
 *
 * NOTE(review): `beforeAll` and several `intercept` blocks have no body in this listing, and
 * closing braces are missing throughout; presumably `beforeAll` registers the test functions
 * — confirm against the upstream file.
 */
class SqlExtensionProviderSuite extends GlowSuite {
  // NOTE(review): body missing from this listing.
  override def beforeAll(): Unit = {

  // Stable reference to the session so that its implicits can be imported inside tests.
  private lazy val sess = spark
  test("one arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("one_arg_test(id)").as[Int].head() == 1)

    // NOTE(review): body missing; presumably a call with too few arguments — confirm.
    intercept[AnalysisException] {

    // One argument too many.
    intercept[AnalysisException] {
      spark.range(1).selectExpr("one_arg_test(id, id)").collect()

  test("two arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("two_arg_test(id, id)").as[Int].head() == 1)

    // NOTE(review): body missing; presumably a call with too few arguments — confirm.
    intercept[AnalysisException] {

    // One argument too many.
    intercept[AnalysisException] {
      spark.range(1).selectExpr("two_arg_test(id, id, id)").collect()

  test("var args function") {
    import sess.implicits._
    // Two, four and one argument(s) are all accepted.
    assert(spark.range(1).selectExpr("var_args_test(id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id, id, id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id)").as[Int].head() == 1)

    // NOTE(review): body missing; presumably a call with no arguments — confirm.
    intercept[AnalysisException] {

  test("can call optional arg function") {
    import sess.implicits._
    // Works both without and with the optional argument.
    assert(spark.range(1).selectExpr("optional_arg_test(id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("optional_arg_test(id, id)").as[Int].head() == 1)

    // NOTE(review): body missing; presumably a call with too few arguments — confirm.
    intercept[AnalysisException] {

    // One argument too many.
    intercept[AnalysisException] {
      spark.range(1).selectExpr("optional_arg_test(id, id, id)").collect()

/**
 * Shared behaviour of the test expressions: the result type is `IntegerType`, the expression
 * is nullable, and evaluation returns 1 regardless of the input row. Code generation falls
 * back to interpreted evaluation (`CodegenFallback`).
 *
 * NOTE(review): the trait's closing brace is missing from this listing.
 */
trait TestExpr extends Expression with CodegenFallback {
  override def dataType: DataType = IntegerType

  override def nullable: Boolean = true

  // Constant result; the input row is ignored.
  override def eval(input: InternalRow): Any = 1

/** Test expression with exactly one child; behaviour comes from `TestExpr`. */
case class OneArgExpr(child: Expression) extends UnaryExpression with TestExpr
/** Test expression with exactly two children; behaviour comes from `TestExpr`. */
case class TwoArgExpr(left: Expression, right: Expression) extends BinaryExpression with TestExpr
/**
 * Test expression taking one required argument followed by any number of further arguments.
 *
 * NOTE(review): the class's closing brace is missing from this listing.
 *
 * @param arg     required first argument
 * @param varArgs remaining arguments, possibly empty
 */
case class VarArgsExpr(arg: Expression, varArgs: Seq[Expression]) extends TestExpr {
  // The required argument comes first, followed by the variable ones.
  override def children: Seq[Expression] = arg +: varArgs
/**
 * Test expression with one required and one optional argument.
 *
 * NOTE(review): the class's closing brace is missing from this listing.
 *
 * @param required argument that must always be supplied
 * @param optional second argument; the one-argument constructor supplies `Literal(1)`
 */
case class OptionalArgExpr(required: Expression, optional: Expression) extends TestExpr {
  // Secondary constructor used when the optional argument is omitted.
  def this(required: Expression) = this(required, Literal(1))
  override def children: Seq[Expression] = Seq(required, optional)