org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback. Mixing this trait into a Catalyst Expression opts it out of Java code generation: the trait's default doGenCode emits code that calls straight back into the expression's interpreted eval, which is why each example below implements only eval (or nullSafeEval) and never any codegen. The project and source file for each example are named above its listing.
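A minimal sketch of that shared pattern (the class name and logic here are illustrative, not taken from any of the projects below):

import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical expression that doubles an integer child. CodegenFallback
// supplies doGenCode, so only the interpreted path needs to be written.
case class DoubleUp(child: Expression) extends UnaryExpression with CodegenFallback {
  override def dataType: DataType = IntegerType
  override protected def nullSafeEval(input: Any): Any = input.asInstanceOf[Int] * 2
}

The trade-off is per-row interpretation overhead, which is why CodegenFallback is typically reserved for expressions that wrap external libraries or complex logic, as in the examples that follow.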
Example 1
Source File: EquivalentExpressions.scala    From drizzle-spark    with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback


// NOTE: only the debugString helper of EquivalentExpressions survives in this
// listing; the rest of the class, including the equivalenceMap that groups
// structurally equal expressions, is elided. The stub field below is an
// assumption added so the snippet compiles on its own.
class EquivalentExpressions {
  private val equivalenceMap =
    mutable.Map.empty[Expression, mutable.ArrayBuffer[Expression]]

  def debugString(all: Boolean = false): String = {
    val sb: mutable.StringBuilder = new StringBuilder()
    sb.append("Equivalent expressions:\n")
    equivalenceMap.foreach { case (k, v) =>
      if (all || v.length > 1) {
        sb.append("  " + v.mkString(", ")).append("\n")
      }
    }
    sb.toString()
  }
} 
Example 2
Source File: XmlDataToCatalyst.scala    From spark-xml    with Apache License 2.0
package com.databricks.spark.xml

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import com.databricks.spark.xml.parsers.StaxXmlParser

case class XmlDataToCatalyst(
    child: Expression,
    schema: DataType,
    options: XmlOptions)
  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {

  override lazy val dataType: DataType = schema

  @transient
  lazy val rowSchema: StructType = schema match {
    case st: StructType => st
    case ArrayType(st: StructType, _) => st
  }

  override def nullSafeEval(xml: Any): Any = xml match {
    case string: UTF8String =>
      CatalystTypeConverters.convertToCatalyst(
        StaxXmlParser.parseColumn(string.toString, rowSchema, options))
    case string: String =>
      StaxXmlParser.parseColumn(string, rowSchema, options)
    case arr: GenericArrayData =>
      CatalystTypeConverters.convertToCatalyst(
        arr.array.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options)))
    case arr: Array[_] =>
      arr.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options))
    case _ => null
  }

  override def inputTypes: Seq[DataType] = schema match {
    case _: StructType => Seq(StringType)
    case ArrayType(_: StructType, _) => Seq(ArrayType(StringType))
  }
} 
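In spark-xml this expression is the engine behind the from_xml column function. A hedged usage sketch (assumes an existing DataFrame df with a string column named xml; the schema is illustrative):

import com.databricks.spark.xml.functions.from_xml
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Parse <book><title>...</title></book> payloads into a struct column.
val bookSchema = StructType(Seq(StructField("title", StringType)))
val parsed = df.withColumn("book", from_xml(df("xml"), bookSchema))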
Example 3
Source File: EquivalentExpressions.scala    From XSQL    with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable


// NOTE: as in Example 1, only the debugString helper survives in this
// listing; the rest of the class, including the equivalenceMap that groups
// structurally equal expressions, is elided. The stub field below is an
// assumption added so the snippet compiles on its own.
class EquivalentExpressions {
  private val equivalenceMap =
    mutable.Map.empty[Expression, mutable.ArrayBuffer[Expression]]

  def debugString(all: Boolean = false): String = {
    val sb: mutable.StringBuilder = new StringBuilder()
    sb.append("Equivalent expressions:\n")
    equivalenceMap.foreach { case (k, v) =>
      if (all || v.length > 1) {
        sb.append("  " + v.mkString(", ")).append("\n")
      }
    }
    sb.toString()
  }
} 
Example 4
Source File: ERPCurrencyConversionExpression.scala    From HANAVora-Extensions    with Apache License 2.0
package org.apache.spark.sql.currency.erp

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.currency.CurrencyConversionException
import org.apache.spark.sql.currency.erp.ERPConversionLoader.RConversionOptionsCurried
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.util.control.NonFatal



case class ERPCurrencyConversionExpression(
    conversionFunction: RConversionOptionsCurried,
    children: Seq[Expression])
  extends Expression
  with ImplicitCastInputTypes
  with CodegenFallback {

  protected val CLIENT_INDEX = 0
  protected val CONVERSION_TYPE_INDEX = 1
  protected val AMOUNT_INDEX = 2
  protected val FROM_INDEX = 3
  protected val TO_INDEX = 4
  protected val DATE_INDEX = 5
  protected val NUM_ARGS = 6

  protected val errorMessage = "Currency conversion library encountered an internal error"


  override def eval(input: InternalRow): Any = {
    val inputArguments = children.map(_.eval(input))

    require(inputArguments.length == NUM_ARGS, "wrong number of arguments")

    // parse arguments
    val client = Option(inputArguments(CLIENT_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val conversionType =
      Option(inputArguments(CONVERSION_TYPE_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val amount = Option(inputArguments(AMOUNT_INDEX).asInstanceOf[Decimal].toJavaBigDecimal)
    val sourceCurrency =
      Option(inputArguments(FROM_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val targetCurrency = Option(inputArguments(TO_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val date = Option(inputArguments(DATE_INDEX).asInstanceOf[UTF8String]).map(_.toString)

    // perform conversion
    val conversion =
      conversionFunction(client, conversionType, sourceCurrency, targetCurrency, date)
    val resultTry = conversion(amount)

    // If 'resultTry' holds a 'Failure', we have to propagate it because potential failure
    // handling already took place. We just wrap it in case it is a cryptic error.
    resultTry.recover {
      case NonFatal(err) => throw new CurrencyConversionException(errorMessage, err)
    }.get.map(Decimal.apply).orNull
  }

  override def dataType: DataType = DecimalType.forType(DoubleType)

  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] =
    Seq(StringType, StringType, DecimalType, StringType, StringType, StringType)

  def inputNames: Seq[String] =
    Seq("client", "conversion_type", "amount", "source", "target", "date")

  def getChild(name: String): Option[Expression] = {
    inputNames.zip(children).find { case (n, _) => name == n }.map(_._2)
  }
} 
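Two idioms in eval above are worth calling out: every child result is wrapped in Option so SQL NULLs become None, and the library's Try result is re-wrapped via recover so an opaque failure surfaces as a CurrencyConversionException. A self-contained sketch of the Try.recover idiom, with stand-in names:

import scala.util.Try
import scala.util.control.NonFatal

class ConversionError(msg: String, cause: Throwable) extends RuntimeException(msg, cause)

// Stand-in for the loaded conversion function, which returns a Try.
def convert(amount: Option[java.math.BigDecimal]): Try[Option[java.math.BigDecimal]] =
  Try(amount.map(_.scaleByPowerOfTen(-2)))

val result = convert(Some(new java.math.BigDecimal("12345")))
  .recover { case NonFatal(err) => throw new ConversionError("conversion failed", err) }
  .get          // rethrows the wrapped error on failure
  .orNull       // SQL semantics: a missing value becomes NULL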
Example 5
Source File: AnnotationFilter.scala    From HANAVora-Extensions    with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{UnresolvedException, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.{InternalRow, trees}
import org.apache.spark.sql.types._


case class AnnotationFilter(child: Expression)(
  val filters: Set[String] = Set.empty,
  val exprId: ExprId = NamedExpression.newExprId)
  extends UnaryExpression
  with NamedExpression
  with CodegenFallback {

  override def name: String = child match {
    case e:NamedExpression => e.name
    case _ => throw new UnresolvedException(this, "name of AnnotationFilter with non-named child")
  }

  override lazy val resolved = childrenResolved

  override def toAttribute: Attribute = {
    if (resolved) {
      child.transform ({
        case a:Alias => a.copy(a.child, a.name)(a.exprId, qualifiers = a.qualifiers,
          explicitMetadata = Some(MetadataAccessor.filterMetadata(a.metadata, filters)))
        case a:AttributeReference =>
          a.copy(a.name, a.dataType, a.nullable,
            metadata = MetadataAccessor.filterMetadata(a.metadata, filters))(a.exprId, a.qualifiers)
        case p => p
      }) match {
        case e: NamedExpression => e.toAttribute
        case _ => throw new UnresolvedException(this, "toAttribute of AnnotationFilter with " +
          "non-named child")
      }
    } else {
      UnresolvedAttribute(name)
    }
  }

  override def equals(other: Any): Boolean = other match {
    case aa: AnnotationFilter => child == aa.child && filters == aa.filters &&
      exprId == aa.exprId
    case _ => false
  }

  // scalastyle:off magic.number
  override def hashCode:Int = {
    List[Int](child.hashCode, filters.hashCode, exprId.hashCode)
      .foldLeft(17)((l, r) => 31 * l + r)
  }

  override def metadata: Metadata = {
    child match {
      case named: NamedExpression => MetadataAccessor.filterMetadata(named.metadata, filters)
      case _ => Metadata.empty
    }
  }

  override def qualifiers: Seq[String] = Nil

  override def eval(input: InternalRow): Any = child.eval(input)

  override def nullable: Boolean = child.nullable

  override def dataType: DataType = child.dataType

  override protected final def otherCopyArgs: Seq[AnyRef] = filters :: exprId :: Nil
} 
Example 6
Source File: stringExpressions.scala    From HANAVora-Extensions    with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.unsafe.types._
import org.apache.spark.sql.types._


case class Replace(se: Expression, fe: Expression, pe: Expression)
  extends TernaryExpression
  with ImplicitCastInputTypes with CodegenFallback {

  override def inputTypes: Seq[AbstractDataType] = Seq.fill(3)(StringType)

  override def eval(input: InternalRow): Any = {
    val s = se.eval(input).asInstanceOf[UTF8String]
    val f = fe.eval(input).asInstanceOf[UTF8String]
    val p = pe.eval(input).asInstanceOf[UTF8String]
    (s, f, p) match {
      // A NULL subject or NULL search string yields SQL NULL.
      case (null, _, _) | (_, null, _) => null
      // A NULL replacement string is treated as the empty string.
      case (stre, strf, null) =>
        UTF8String.fromString(stre.toString.replaceAllLiterally(strf.toString, ""))
      case (stre, strf, strp) =>
        UTF8String.fromString(stre.toString.replaceAllLiterally(strf.toString, strp.toString))
    }
  }

  // The result is also NULL when the search string evaluates to NULL.
  override def nullable: Boolean = se.nullable || fe.nullable

  override def dataType: DataType = StringType

  override def children: Seq[Expression] = se :: fe :: pe :: Nil
} 
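A hedged sketch of exercising Replace directly through the interpreted path (Literal lifts Scala strings to UTF8String, and eval(null) works because literals ignore the input row):

import org.apache.spark.sql.catalyst.expressions.Literal

val replaced = Replace(Literal("banana"), Literal("an"), Literal("-")).eval(null)
// replaced == UTF8String.fromString("b--a"): both "an" occurrences rewritten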
Example 7
Source File: dateExpressions.scala    From HANAVora-Extensions    with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._


case class AddYears(date: Expression, years: Expression)
  extends BinaryExpression
  with ImplicitCastInputTypes with CodegenFallback {

  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

  override def nullSafeEval(d: Any, y: Any): Any = {
    DateTimeUtils.dateAddMonths(
      d.asInstanceOf[DateTimeUtils.SQLDate], y.asInstanceOf[Int] * 12
    )
  }

  override def left: Expression = date
  override def right: Expression = years
  override def dataType: DataType = DateType
} 
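Catalyst ships no year-level date arithmetic, so the implementation simply delegates to DateTimeUtils.dateAddMonths with years * 12. A hedged evaluation sketch (Literal converts a java.sql.Date to Catalyst's internal day count):

import java.sql.Date
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.DateTimeUtils

val days = AddYears(Literal(Date.valueOf("2015-03-31")), Literal(2))
  .eval(null).asInstanceOf[DateTimeUtils.SQLDate]
println(DateTimeUtils.toJavaDate(days))  // 2017-03-31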
Example 8
Source File: AddSeconds.scala    From HANAVora-Extensions    with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.util.DateTimeUtils


case class AddSeconds(timestamp: Expression, seconds: Expression)
  extends BinaryExpression
  with ImplicitCastInputTypes
  with CodegenFallback {

  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, IntegerType)

  override def nullSafeEval(microseconds: Any, seconds: Any): Any = {
    microseconds.asInstanceOf[DateTimeUtils.SQLTimestamp] +
      (seconds.asInstanceOf[Int] * DateTimeUtils.MICROS_PER_SECOND)
  }

  override def left: Expression = timestamp
  override def right: Expression = seconds
  override def dataType: DataType = TimestampType
} 
Example 9
Source File: serdes.scala    From magellan    with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import magellan._
import magellan.catalyst.MagellanExpression
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._


case class MagellanSerializer(
    override val child: Expression,
    _dataType: DataType)
  extends UnaryExpression
  with MagellanExpression
  with CodegenFallback
  with NonSQLExpression {

  override def nullable: Boolean = false

  override protected def nullSafeEval(input: Any): Any = {
    val shape = input.asInstanceOf[Shape]
    serialize(shape)
  }

  override def dataType: DataType = _dataType
}

case class MagellanDeserializer(
    override val child: Expression, klass: Class[_ <: Shape])
  extends UnaryExpression
  with MagellanExpression
  with CodegenFallback
  with NonSQLExpression {

  override def nullable: Boolean = false

  override protected def nullSafeEval(input: Any): Any = {
    newInstance(input.asInstanceOf[InternalRow])
  }

  override def dataType: DataType = ObjectType(klass)
} 
Example 10
Source File: SqlExtensionProviderSuite.scala    From glow    with Apache License 2.0
package io.projectglow.sql

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Literal, UnaryExpression}
import org.apache.spark.sql.types.{DataType, IntegerType}

import io.projectglow.GlowSuite

class SqlExtensionProviderSuite extends GlowSuite {
  override def beforeAll(): Unit = {
    super.beforeAll()
    SqlExtensionProvider.registerFunctions(
      spark.sessionState.conf,
      spark.sessionState.functionRegistry,
      "test-functions.yml")
  }

  private lazy val sess = spark

  test("one arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("one_arg_test(id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("one_arg_test()").collect()
    }

    intercept[AnalysisException] {
      spark.range(1).selectExpr("one_arg_test(id, id)").collect()
    }
  }

  test("two arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("two_arg_test(id, id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("two_arg_test(id)").collect()
    }

    intercept[AnalysisException] {
      spark.range(1).selectExpr("two_arg_test(id, id, id)").collect()
    }
  }

  test("var args function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("var_args_test(id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id, id, id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("var_args_test()").collect()
    }
  }

  test("can call optional arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("optional_arg_test(id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("optional_arg_test(id, id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("optional_arg_test()").collect()
    }

    intercept[AnalysisException] {
      spark.range(1).selectExpr("optional_arg_test(id, id, id)").collect()
    }
  }
}

trait TestExpr extends Expression with CodegenFallback {
  override def dataType: DataType = IntegerType

  override def nullable: Boolean = true

  override def eval(input: InternalRow): Any = 1
}

case class OneArgExpr(child: Expression) extends UnaryExpression with TestExpr
case class TwoArgExpr(left: Expression, right: Expression) extends BinaryExpression with TestExpr
case class VarArgsExpr(arg: Expression, varArgs: Seq[Expression]) extends TestExpr {
  override def children: Seq[Expression] = arg +: varArgs
}
case class OptionalArgExpr(required: Expression, optional: Expression) extends TestExpr {
  def this(required: Expression) = this(required, Literal(1))
  override def children: Seq[Expression] = Seq(required, optional)
} 
Example 11
Source File: InRange.scala    From Simba    with Apache License 2.0
package org.apache.spark.sql.simba.expression

import org.apache.spark.sql.simba.{ShapeSerializer, ShapeType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Predicate}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.simba.spatial.{MBR, Point, Shape}
import org.apache.spark.sql.simba.util.ShapeUtils
import org.apache.spark.sql.catalyst.util.GenericArrayData


case class InRange(shape: Expression, range_low: Expression, range_high: Expression)
  extends Predicate with CodegenFallback {
  override def nullable: Boolean = false

  override def eval(input: InternalRow): Any = {
    val eval_shape = ShapeUtils.getShape(shape, input)
    val eval_low = range_low.asInstanceOf[Literal].value.asInstanceOf[Point]
    val eval_high = range_high.asInstanceOf[Literal].value.asInstanceOf[Point]
    require(eval_shape.dimensions == eval_low.dimensions && eval_shape.dimensions == eval_high.dimensions)
    val mbr = MBR(eval_low, eval_high)
    mbr.intersects(eval_shape)
  }

  override def toString: String = s"($shape) IN Rectangle ($range_low) - ($range_high)"

  override def children: Seq[Expression] = Seq(shape, range_low, range_high)
} 
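A hedged sketch of the underlying range test, assuming Simba's Point takes an Array[Double] of coordinates and MBR takes low/high corner points (constructor shapes inferred from the usage above, not verified against the project):

import org.apache.spark.sql.simba.spatial.{MBR, Point}

// Assumed constructors: Point(coord: Array[Double]), MBR(low: Point, high: Point).
val mbr = MBR(Point(Array(0.0, 0.0)), Point(Array(10.0, 10.0)))
mbr.intersects(Point(Array(5.0, 5.0)))  // true: the point lies inside the box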
Example 12
Source File: BasicCurrencyConversionExpression.scala    From HANAVora-Extensions    with Apache License 2.0
package org.apache.spark.sql.currency.basic

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


case class BasicCurrencyConversionExpression(
    conversion: BasicCurrencyConversion,
    children: Seq[Expression])
  extends Expression
  with ImplicitCastInputTypes
  with CodegenFallback {

  protected val AMOUNT_INDEX = 0
  protected val FROM_INDEX = 1
  protected val TO_INDEX = 2
  protected val DATE_INDEX = 3
  protected val NUM_ARGS = 4

  override def eval(input: InternalRow): Any = {
    val inputArguments = children.map(_.eval(input))

    require(inputArguments.length == NUM_ARGS, "wrong number of arguments")

    val sourceCurrency =
      Option(inputArguments(FROM_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val targetCurrency = Option(inputArguments(TO_INDEX).asInstanceOf[UTF8String]).map(_.toString)
    val amount = Option(inputArguments(AMOUNT_INDEX).asInstanceOf[Decimal].toJavaBigDecimal)
    val date = Option(inputArguments(DATE_INDEX).asInstanceOf[UTF8String]).map(_.toString)

    (amount, sourceCurrency, targetCurrency, date) match {
      case (Some(a), Some(s), Some(t), Some(d)) => nullSafeEval(a, s, t, d)
      case _ => null
    }
  }

  def nullSafeEval(amount: java.math.BigDecimal,
                   sourceCurrency: String,
                   targetCurrency: String,
                   date: String): Any = {
    conversion.convert(amount, sourceCurrency, targetCurrency, date)
      .get
      .map(Decimal.apply)
      .orNull
  }

  override def dataType: DataType = DecimalType.forType(DoubleType)

  override def nullable: Boolean = true

  // TODO(MD, CS): use DateType but support date string
  override def inputTypes: Seq[AbstractDataType] =
    Seq(DecimalType, StringType, StringType, StringType)
}
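Like Example 4, the eval above follows the all-or-nothing null idiom: every child result is wrapped in Option, and nullSafeEval runs only when all four are present. The same idiom in a generic, illustrative form:

// Illustrative helper, not part of the project: apply f only when every
// argument is non-null, otherwise return SQL NULL.
def evalAll(args: Seq[Option[Any]])(f: Seq[Any] => Any): Any =
  if (args.forall(_.isDefined)) f(args.map(_.get)) else null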