org.apache.spark.sql.catalyst.expressions.codegen.ExprCode Scala Examples
The following examples show how to use org.apache.spark.sql.catalyst.expressions.codegen.ExprCode.
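For context: ExprCode is the small container Catalyst's code generator threads through doGenCode. Its code field holds the generated Java statements, while isNull and value name the Java variables carrying the expression's null flag and result. The sketch below is not from any of the projects listed here; it is a minimal, hypothetical expression written against the string-based ExprCode API of Spark 2.x (which most examples on this page target) to show the basic contract: doGenCode must return an ExprCode whose statements declare and assign ev.value.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical minimal expression: always evaluates to the constant 42.
case class ConstFortyTwo() extends LeafExpression {
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType

  // Interpreted (non-codegen) path.
  override def eval(input: InternalRow): Any = 42

  // Codegen path: ev.value is the fresh Java variable name Catalyst allocated
  // for this expression's result; the returned ExprCode carries the Java
  // statements that declare and assign it. isNull = "false" tells callers
  // that no null check is needed.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code = s"final int ${ev.value} = 42;", isNull = "false")
  }
}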
Example 1
Source File: AvroDataToCatalyst.scala From spark-schema-registry with Apache License 2.0
package com.hortonworks.spark.registry.avro

import java.io.ByteArrayInputStream

import com.hortonworks.registries.schemaregistry.{SchemaVersionInfo, SchemaVersionKey}
import com.hortonworks.registries.schemaregistry.client.SchemaRegistryClient
import com.hortonworks.registries.schemaregistry.serdes.avro.AvroSnapshotDeserializer
import org.apache.avro.Schema
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

import scala.collection.JavaConverters._

case class AvroDataToCatalyst(
    child: Expression,
    schemaName: String,
    version: Option[Int],
    config: Map[String, Object])
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes = Seq(BinaryType)

  @transient private lazy val srDeser: AvroSnapshotDeserializer = {
    val obj = new AvroSnapshotDeserializer()
    obj.init(config.asJava)
    obj
  }

  @transient private lazy val srSchema = fetchSchemaVersionInfo(schemaName, version)

  @transient private lazy val avroSchema = new Schema.Parser().parse(srSchema.getSchemaText)

  override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType

  @transient private lazy val avroDeser = new AvroDeserializer(avroSchema, dataType)

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val binary = input.asInstanceOf[Array[Byte]]
    val row = avroDeser.deserialize(
      srDeser.deserialize(new ByteArrayInputStream(binary), srSchema.getVersion))
    val result = row match {
      case r: InternalRow => r.copy()
      case _ => row
    }
    result
  }

  override def simpleString: String = {
    s"from_sr(${child.sql}, ${dataType.simpleString})"
  }

  override def sql: String = {
    s"from_sr(${child.sql}, ${dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input => s"(${ctx.boxedType(dataType)})$expr.nullSafeEval($input)")
  }

  private def fetchSchemaVersionInfo(
      schemaName: String,
      version: Option[Int]): SchemaVersionInfo = {
    val srClient = new SchemaRegistryClient(config.asJava)
    version.map(v => srClient.getSchemaVersionInfo(new SchemaVersionKey(schemaName, v)))
      .getOrElse(srClient.getLatestSchemaVersionInfo(schemaName))
  }
}
Example 2
Source File: ReferenceToExpressions.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType

case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
      case (childGen, child) =>
        // SPARK-18125: The children vars are local variables. If the result expression uses
        // splitExpression, those variables cannot be accessed so compilation fails.
        // To fix it, we use class variables to hold those local variables.
        val classChildVarName = ctx.freshName("classChildVar")
        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
        ctx.addMutableState("boolean", classChildVarIsNull, "")

        val classChildVar =
          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)

        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
          s"${classChildVar.isNull} = ${childGen.isNull};"

        (classChildVar, initCode)
    }.unzip

    val resultGen = result.transform {
      case b: BoundReference => classChildrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(
      code = childrenGen.map(_.code).mkString("\n") +
        initClassChildrenVars.mkString("\n") + resultGen.code,
      isNull = resultGen.isNull,
      value = resultGen.value)
  }
}
Example 3
Source File: BoundAttribute.scala From sparkoscope with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = ctx.javaType(dataType)
    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      val code = oev.code
      oev.code = ""
      ev.copy(code = code)
    } else if (nullable) {
      ev.copy(code = s"""
        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
    } else {
      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform {
      case a: AttributeReference =>
        attachTree(a, "Binding attribute") {
          val ordinal = input.indexOf(a.exprId)
          if (ordinal == -1) {
            if (allowFailures) {
              a
            } else {
              sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
            }
          } else {
            BoundReference(ordinal, a.dataType, input(ordinal).nullable)
          }
        }
    }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
  }
}
Example 4
Source File: decimalExpressions.scala From sparkoscope with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val d = input.asInstanceOf[Decimal].clone()
    if (d.changePrecision(dataType.precision, dataType.scale)) {
      d
    } else {
      null
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
}
Example 5
Source File: ReferenceToExpressions.scala From sparkoscope with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType

case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
      case (childGen, child) =>
        // SPARK-18125: The children vars are local variables. If the result expression uses
        // splitExpression, those variables cannot be accessed so compilation fails.
        // To fix it, we use class variables to hold those local variables.
        val classChildVarName = ctx.freshName("classChildVar")
        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
        ctx.addMutableState("boolean", classChildVarIsNull, "")

        val classChildVar =
          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)

        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
          s"${classChildVar.isNull} = ${childGen.isNull};"

        (classChildVar, initCode)
    }.unzip

    val resultGen = result.transform {
      case b: BoundReference => classChildrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(
      code = childrenGen.map(_.code).mkString("\n") +
        initClassChildrenVars.mkString("\n") + resultGen.code,
      isNull = resultGen.isNull,
      value = resultGen.value)
  }
}
Example 6
Source File: ExpressionEvalHelperSuite.scala From sparkoscope with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
         |int some_variable = 11;
         |int ${ev.value} = 10;
       """.stripMargin)
  }
  override def dataType: DataType = IntegerType
}
Example 7
Source File: CheckDeltaInvariant.scala From delta with Apache License 2.0
package org.apache.spark.sql.delta.schema

import org.apache.spark.sql.delta.schema.Invariants.{ArbitraryExpression, NotNull}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, NullType}

case class CheckDeltaInvariant(
    child: Expression,
    invariant: Invariant) extends UnaryExpression with NonSQLExpression {

  override def dataType: DataType = NullType
  override def foldable: Boolean = false
  override def nullable: Boolean = true

  override def flatArguments: Iterator[Any] = Iterator(child)

  private def assertRule(input: InternalRow): Unit = invariant.rule match {
    case NotNull if child.eval(input) == null =>
      throw InvariantViolationException(invariant, "")
    case ArbitraryExpression(expr) =>
      val resolvedExpr = expr.transform {
        case _: UnresolvedAttribute => child
      }
      val result = resolvedExpr.eval(input)
      if (result == null || result == false) {
        throw InvariantViolationException(
          invariant, s"Value ${child.eval(input)} violates requirement.")
      }
  }

  override def eval(input: InternalRow): Any = {
    assertRule(input)
    null
  }

  private def generateNotNullCode(ctx: CodegenContext): Block = {
    val childGen = child.genCode(ctx)
    val invariantField = ctx.addReferenceObj("errMsg", invariant)
    code"""${childGen.code}
       |
       |if (${childGen.isNull}) {
       |  throw org.apache.spark.sql.delta.schema.InvariantViolationException.apply(
       |    $invariantField, "");
       |}
     """.stripMargin
  }

  private def generateExpressionValidationCode(
      expr: Expression, ctx: CodegenContext): Block = {
    val resolvedExpr = expr.transform {
      case _: UnresolvedAttribute => child
    }
    val elementValue = child.genCode(ctx)
    val childGen = resolvedExpr.genCode(ctx)
    val invariantField = ctx.addReferenceObj("errMsg", invariant)
    val eValue = ctx.freshName("elementResult")
    code"""${elementValue.code}
       |${childGen.code}
       |
       |if (${childGen.isNull} || ${childGen.value} == false) {
       |  Object $eValue = "null";
       |  if (!${elementValue.isNull}) {
       |    $eValue = (Object) ${elementValue.value};
       |  }
       |  throw org.apache.spark.sql.delta.schema.InvariantViolationException.apply(
       |    $invariantField, "Value " + $eValue + " violates requirement.");
       |}
     """.stripMargin
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val code = invariant.rule match {
      case NotNull => generateNotNullCode(ctx)
      case ArbitraryExpression(expr) => generateExpressionValidationCode(expr, ctx)
    }
    ev.copy(code = code, isNull = TrueLiteral, value = JavaCode.literal("null", NullType))
  }
}
Example 8
Source File: Serialize.scala From morpheus with Apache License 2.0
package org.opencypher.morpheus.impl.expressions

import java.io.ByteArrayOutputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, _}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.opencypher.morpheus.impl.expressions.EncodeLong.encodeLong
import org.opencypher.morpheus.impl.expressions.Serialize._
import org.opencypher.okapi.impl.exception

case class Serialize(children: Seq[Expression]) extends Expression {

  override def dataType: DataType = BinaryType

  override def nullable: Boolean = false

  // TODO: Only write length if more than one column is serialized
  override def eval(input: InternalRow): Any = {
    // TODO: Reuse from a pool instead of allocating a new one for each serialization
    val out = new ByteArrayOutputStream()
    children.foreach { child =>
      child.dataType match {
        case BinaryType => write(child.eval(input).asInstanceOf[Array[Byte]], out)
        case StringType => write(child.eval(input).asInstanceOf[UTF8String], out)
        case IntegerType => write(child.eval(input).asInstanceOf[Int], out)
        case LongType => write(child.eval(input).asInstanceOf[Long], out)
        case other =>
          throw exception.UnsupportedOperationException(s"Cannot serialize Spark data type $other.")
      }
    }
    out.toByteArray
  }

  override protected def doGenCode(
    ctx: CodegenContext,
    ev: ExprCode
  ): ExprCode = {
    ev.isNull = FalseLiteral
    val out = ctx.freshName("out")
    val serializeChildren = children.map { child =>
      val childEval = child.genCode(ctx)
      s"""|${childEval.code}
          |if (!${childEval.isNull}) {
          |  ${Serialize.getClass.getName.dropRight(1)}.write(${childEval.value}, $out);
          |}""".stripMargin
    }.mkString("\n")
    val baos = classOf[ByteArrayOutputStream].getName
    ev.copy(
      code = code"""|$baos $out = new $baos();
                    |$serializeChildren
                    |byte[] ${ev.value} = $out.toByteArray();""".stripMargin)
  }
}

object Serialize {

  val supportedTypes: Set[DataType] = Set(BinaryType, StringType, IntegerType, LongType)

  @inline final def write(value: Array[Byte], out: ByteArrayOutputStream): Unit = {
    out.write(encodeLong(value.length))
    out.write(value)
  }

  @inline final def write(
    value: Boolean,
    out: ByteArrayOutputStream
  ): Unit = write(if (value) 1.toLong else 0.toLong, out)

  @inline final def write(value: Byte, out: ByteArrayOutputStream): Unit =
    write(value.toLong, out)

  @inline final def write(value: Int, out: ByteArrayOutputStream): Unit =
    write(value.toLong, out)

  @inline final def write(value: Long, out: ByteArrayOutputStream): Unit =
    write(encodeLong(value), out)

  @inline final def write(value: UTF8String, out: ByteArrayOutputStream): Unit =
    write(value.getBytes, out)

  @inline final def write(value: String, out: ByteArrayOutputStream): Unit =
    write(value.getBytes, out)
}
Example 9
Source File: EncodeLong.scala From morpheus with Apache License 2.0
package org.opencypher.morpheus.impl.expressions

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant, UnaryExpression}
import org.apache.spark.sql.types.{BinaryType, DataType, LongType}
import org.opencypher.morpheus.api.value.MorpheusElement._

case class EncodeLong(child: Expression)
  extends UnaryExpression with NullIntolerant with ExpectsInputTypes {

  override val dataType: DataType = BinaryType

  override val inputTypes: Seq[LongType] = Seq(LongType)

  override protected def nullSafeEval(input: Any): Any =
    EncodeLong.encodeLong(input.asInstanceOf[Long])

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev,
      c => s"(byte[])(${EncodeLong.getClass.getName.dropRight(1)}.encodeLong($c))")
}

object EncodeLong {

  private final val moreBytesBitMask: Long = Integer.parseInt("10000000", 2)
  private final val varLength7BitMask: Long = Integer.parseInt("01111111", 2)
  private final val otherBitsMask = ~varLength7BitMask
  private final val maxBytesForLongVarEncoding = 10

  // Same encoding as Base 128 Varints @ https://developers.google.com/protocol-buffers/docs/encoding
  @inline final def encodeLong(l: Long): Array[Byte] = {
    val tempResult = new Array[Byte](maxBytesForLongVarEncoding)

    var remainder = l
    var index = 0

    while ((remainder & otherBitsMask) != 0) {
      tempResult(index) = ((remainder & varLength7BitMask) | moreBytesBitMask).toByte
      remainder >>>= 7
      index += 1
    }
    tempResult(index) = remainder.toByte

    val result = new Array[Byte](index + 1)
    System.arraycopy(tempResult, 0, result, 0, index + 1)
    result
  }

  // Same encoding as Base 128 Varints @ https://developers.google.com/protocol-buffers/docs/encoding
  @inline final def decodeLong(input: Array[Byte]): Long = {
    assert(input.nonEmpty, "`decodeLong` requires a non-empty array as its input")
    var index = 0
    var currentByte = input(index)
    var decoded = currentByte & varLength7BitMask
    var nextLeftShift = 7

    while ((currentByte & moreBytesBitMask) != 0) {
      index += 1
      currentByte = input(index)
      decoded |= (currentByte & varLength7BitMask) << nextLeftShift
      nextLeftShift += 7
    }
    assert(index == input.length - 1,
      s"`decodeLong` received an input array ${input.toSeq.toHex} with extra bytes that could not be decoded.")
    decoded
  }

  implicit class ColumnLongOps(val c: Column) extends AnyVal {

    def encodeLongAsMorpheusId(name: String): Column = encodeLongAsMorpheusId.as(name)

    def encodeLongAsMorpheusId: Column = new Column(EncodeLong(c.expr))
  }
}
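A quick usage sketch of the varint helpers above, round-tripping a few values through the companion object's encodeLong/decodeLong (the expected byte counts follow from the 7-bit groups in the encoding loop):

import org.opencypher.morpheus.impl.expressions.EncodeLong.{decodeLong, encodeLong}

// 0 fits in a single byte, 300 needs two, and -1 uses the full ten bytes.
assert(decodeLong(encodeLong(0L)) == 0L)
assert(decodeLong(encodeLong(300L)) == 300L)
assert(decodeLong(encodeLong(-1L)) == -1L)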
Example 10
Source File: MonotonicallyIncreasingID.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
}
Example 11
Source File: BoundAttribute.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = ctx.javaType(dataType)
    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      val code = oev.code
      oev.code = ""
      ev.copy(code = code)
    } else if (nullable) {
      ev.copy(code = s"""
        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
    } else {
      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform {
      case a: AttributeReference =>
        attachTree(a, "Binding attribute") {
          val ordinal = input.indexOf(a.exprId)
          if (ordinal == -1) {
            if (allowFailures) {
              a
            } else {
              sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
            }
          } else {
            BoundReference(ordinal, a.dataType, input(ordinal).nullable)
          }
        }
    }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
  }
}
Example 12
Source File: decimalExpressions.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val d = input.asInstanceOf[Decimal].clone()
    if (d.changePrecision(dataType.precision, dataType.scale)) {
      d
    } else {
      null
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
}
Example 13
Source File: MonotonicallyIncreasingID.scala From sparkoscope with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
}
Example 14
Source File: ExpressionEvalHelperSuite.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
         |int some_variable = 11;
         |int ${ev.value} = 10;
       """.stripMargin)
  }
  override def dataType: DataType = IntegerType
}
Example 15
Source File: MonotonicallyIncreasingID.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
    val partitionMaskTerm = "partitionMask"
    ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
}
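Compared with the Spark 2.1 variants of this file shown in Examples 10 and 13, note the codegen API change visible here: ctx.addMutableState(ctx.JAVA_LONG, "count") now returns the actual (possibly renamed) field name to use in the generated code, and the partition mask is registered through ctx.addImmutableStateIfNotExists so the same field can be shared rather than duplicated per expression instance.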
Example 16
Source File: BoundAttribute.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      ev.copy(code = oev.code)
    } else {
      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
      val javaType = ctx.javaType(dataType)
      val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
      if (nullable) {
        ev.copy(code =
          s"""
             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
             |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
           """.stripMargin)
      } else {
        ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
      }
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform {
      case a: AttributeReference =>
        attachTree(a, "Binding attribute") {
          val ordinal = input.indexOf(a.exprId)
          if (ordinal == -1) {
            if (allowFailures) {
              a
            } else {
              sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
            }
          } else {
            BoundReference(ordinal, a.dataType, input(ordinal).nullable)
          }
        }
    }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
  }
}
Example 17
Source File: decimalExpressions.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
}
Example 18
Source File: ExpressionEvalHelperSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
         |int some_variable = 11;
         |int ${ev.value} = 10;
       """.stripMargin)
  }
  override def dataType: DataType = IntegerType
}
Example 19
Source File: CatalystDataToAvro.scala From spark-schema-registry with Apache License 2.0
package com.hortonworks.spark.registry.avro

import com.hortonworks.registries.schemaregistry.{SchemaCompatibility, SchemaMetadata}
import com.hortonworks.registries.schemaregistry.avro.AvroSchemaProvider
import com.hortonworks.registries.schemaregistry.client.SchemaRegistryClient
import com.hortonworks.registries.schemaregistry.serdes.avro.AvroSnapshotSerializer
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

import scala.collection.JavaConverters._

case class CatalystDataToAvro(
    child: Expression,
    schemaName: String,
    recordName: String,
    nameSpace: String,
    config: Map[String, Object]
  ) extends UnaryExpression {

  override def dataType: DataType = BinaryType

  private val topLevelRecordName = if (recordName == "") schemaName else recordName

  @transient private lazy val avroType =
    SchemaConverters.toAvroType(child.dataType, child.nullable, topLevelRecordName, nameSpace)

  @transient private lazy val avroSer =
    new AvroSerializer(child.dataType, avroType, child.nullable)

  @transient private lazy val srSer: AvroSnapshotSerializer = {
    val obj = new AvroSnapshotSerializer()
    obj.init(config.asJava)
    obj
  }

  @transient private lazy val srClient = new SchemaRegistryClient(config.asJava)

  @transient private lazy val schemaMetadata = {
    var schemaMetadataInfo = srClient.getSchemaMetadataInfo(schemaName)
    if (schemaMetadataInfo == null) {
      val generatedSchemaMetadata = new SchemaMetadata.Builder(schemaName).
        `type`(AvroSchemaProvider.TYPE)
        .schemaGroup("Autogenerated group")
        .description("Autogenerated schema")
        .compatibility(SchemaCompatibility.BACKWARD).build
      srClient.addSchemaMetadata(generatedSchemaMetadata)
      generatedSchemaMetadata
    } else {
      schemaMetadataInfo.getSchemaMetadata
    }
  }

  override def nullSafeEval(input: Any): Any = {
    val avroData = avroSer.serialize(input)
    srSer.serialize(avroData.asInstanceOf[Object], schemaMetadata)
  }

  override def simpleString: String = {
    s"to_sr(${child.sql}, ${child.dataType.simpleString})"
  }

  override def sql: String = {
    s"to_sr(${child.sql}, ${child.dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input => s"(byte[]) $expr.nullSafeEval($input)")
  }
}
Example 20
Source File: MeanSubstitute.scala From glow with Apache License 2.0
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution

case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  def this(array: Expression) = {
    this(array, Literal(-1))
  }

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)
  }

  // Update sum and count with array element if not missing
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    If(
      isMissing(arrayElement),
      // If value is missing, do not update sum and count
      stateStruct,
      // If value is not missing, add to sum and increment count
      createNamedStruct(
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)
    )
  }

  // Calculate mean for imputation
  def calculateMean(stateStruct: Expression): Expression = {
    If(
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value
      missingValue
    )
  }

  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
    array.aggregate(
      createNamedStruct(Literal(0d), Literal(0L)),
      updateSumAndCountConditionally,
      calculateMean
    )
  }

  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)
  }

  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }

    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    // Replace missing values with the provided strategy
    array.arrayTransform(substituteWithMean(_))
  }
}
Example 21
Source File: LinearRegressionExpr.scala From glow with Apache License 2.0
package io.projectglow.sql.expressions

import breeze.linalg.DenseVector
import org.apache.spark.TaskContext
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, TernaryExpression}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._

object LinearRegressionExpr {
  private val matrixUDT = SQLUtils.newMatrixUDT()
  private val state = new ThreadLocal[CovariateQRContext]

  def doLinearRegression(genotypes: Any, phenotypes: Any, covariates: Any): InternalRow = {
    if (state.get() == null) {
      // Save the QR factorization of the covariate matrix since it's the same for every row
      state.set(CovariateQRContext.computeQR(matrixUDT.deserialize(covariates).toDense))
      TaskContext.get().addTaskCompletionListener[Unit](_ => state.remove())
    }

    LinearRegressionGwas.linearRegressionGwas(
      new DenseVector[Double](genotypes.asInstanceOf[ArrayData].toDoubleArray()),
      new DenseVector[Double](phenotypes.asInstanceOf[ArrayData].toDoubleArray()),
      state.get()
    )
  }
}

case class LinearRegressionExpr(
    genotypes: Expression,
    phenotypes: Expression,
    covariates: Expression)
    extends TernaryExpression
    with ImplicitCastInputTypes {

  private val matrixUDT = SQLUtils.newMatrixUDT()

  override def dataType: DataType =
    StructType(
      Seq(
        StructField("beta", DoubleType),
        StructField("standardError", DoubleType),
        StructField("pValue", DoubleType)))

  override def inputTypes: Seq[DataType] =
    Seq(ArrayType(DoubleType), ArrayType(DoubleType), matrixUDT)

  override def children: Seq[Expression] = Seq(genotypes, phenotypes, covariates)

  override protected def nullSafeEval(genotypes: Any, phenotypes: Any, covariates: Any): Any = {
    LinearRegressionExpr.doLinearRegression(genotypes, phenotypes, covariates)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(
      ctx,
      ev,
      (genotypes, phenotypes, covariates) => {
        s"""
           |${ev.value} = io.projectglow.sql.expressions.LinearRegressionExpr.doLinearRegression($genotypes, $phenotypes, $covariates);
         """.stripMargin
      }
    )
  }
}
Example 22
Source File: inputFileBlock.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = false

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileBlockHolder.getInputFilePath
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
      isNull = FalseLiteral)
  }
}

@ExpressionDescription(
  usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.")
case class InputFileBlockStart() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_start"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getStartOffset
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();",
      isNull = FalseLiteral)
  }
}

@ExpressionDescription(
  usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.")
case class InputFileBlockLength() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_length"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getLength
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();",
      isNull = FalseLiteral)
  }
}
Example 23
Source File: MonotonicallyIncreasingID.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initInternal(): Unit = {
    count = 0L
    partitionMask = TaskContext.getPartitionId().toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
}
Example 24
Source File: randomExpressions.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

// RDG, the abstract base class providing the per-partition `rng`, is defined
// earlier in this file and is elided from this excerpt.
@ExpressionDescription(
  usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.")
case class Randn(seed: Long) extends RDG {

  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  def this() = this(Utils.random.nextLong())

  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
  })

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val rngTerm = ctx.freshName("rng")
    val className = classOf[XORShiftRandom].getName
    ctx.addMutableState(className, rngTerm,
      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
      isNull = "false")
  }
}
Example 25
Source File: TimeWindow.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

case class TimeWindow(
    timeColumn: Expression,
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long) extends UnaryExpression
  with ImplicitCastInputTypes
  with Unevaluable
  with NonSQLExpression {

  //////////////////////////
  // SQL Constructors
  //////////////////////////
  def this(
      timeColumn: Expression,
      windowDuration: Expression,
      slideDuration: Expression,
      startTime: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
  }

  def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), 0)
  }

  def this(timeColumn: Expression, windowDuration: Expression) = {
    this(timeColumn, windowDuration, windowDuration)
  }

  override def child: Expression = timeColumn
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = new StructType()
    .add(StructField("start", TimestampType))
    .add(StructField("end", TimestampType))

  // This expression is replaced in the analyzer.
  override lazy val resolved = false
}

// The companion object TimeWindow (including parseExpression, used by the
// constructors above) is elided from this excerpt.
case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = LongType
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      s"""boolean ${ev.isNull} = ${eval.isNull};
         |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
       """.stripMargin)
  }
}
Example 26
Source File: BoundAttribute.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = ctx.javaType(dataType)
    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      val code = oev.code
      oev.code = ""
      ev.copy(code = code)
    } else if (nullable) {
      ev.copy(code = s"""
        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
    } else {
      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform {
      case a: AttributeReference =>
        attachTree(a, "Binding attribute") {
          val ordinal = input.indexOf(a.exprId)
          if (ordinal == -1) {
            if (allowFailures) {
              a
            } else {
              sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
            }
          } else {
            BoundReference(ordinal, a.dataType, input(ordinal).nullable)
          }
        }
    }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
  }
}
Example 27
Source File: decimalExpressions.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._

case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val d = input.asInstanceOf[Decimal].clone()
    if (d.changePrecision(dataType.precision, dataType.scale)) {
      d
    } else {
      null
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
}
Example 28
Source File: ReferenceToExpressions.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType

case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val childrenVars = childrenGen.zip(children).map {
      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
    }

    val resultGen = result.transform {
      case b: BoundReference => childrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(
      code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
      isNull = resultGen.isNull,
      value = resultGen.value)
  }
}
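This older drizzle-spark version simply wraps each child's generated variables in a LambdaVariable before substituting them into the result expression. The later versions of the same file shown in Examples 2 and 5 instead copy the children into class-level mutable state (the SPARK-18125 fix), so the generated code still compiles when Catalyst splits long expressions into separate methods that cannot see the original local variables.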
Example 29
Source File: ExpressionEvalHelperSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
         |int some_variable = 11;
         |int ${ev.value} = 10;
       """.stripMargin)
  }
  override def dataType: DataType = IntegerType
}
Example 30
Source File: package.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}

// Excerpt: the members below belong to the DebugExec physical operator; its
// declaration is restored here so the snippet parses (the SetAccumulator helper
// and the package-level debugPrint it calls are elided in this listing).
case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
  def output: Seq[Attribute] = child.output

  case class ColumnMetrics() {
    val elementTypes = new SetAccumulator[String]
    sparkContext.register(elementTypes)
  }

  val tupleCount: LongAccumulator = sparkContext.longAccumulator

  val numColumns: Int = child.output.size
  val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())

  def dumpStats(): Unit = {
    debugPrint(s"== ${child.simpleString} ==")
    debugPrint(s"Tuples output: ${tupleCount.value}")
    child.output.zip(columnStats).foreach { case (attr, metric) =>
      // This is called on driver. All accumulator updates have a fixed value. So it's safe to use
      // `asScala` which accesses the internal values using `java.util.Iterator`.
      val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}")
      debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
    }
  }

  protected override def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitions { iter =>
      new Iterator[InternalRow] {
        def hasNext: Boolean = iter.hasNext

        def next(): InternalRow = {
          val currentRow = iter.next()
          tupleCount.add(1)
          var i = 0
          while (i < numColumns) {
            val value = currentRow.get(i, output(i).dataType)
            if (value != null) {
              columnStats(i).elementTypes.add(value.getClass.getName)
            }
            i += 1
          }
          currentRow
        }
      }
    }
  }

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    child.asInstanceOf[CodegenSupport].inputRDDs()
  }

  override def doProduce(ctx: CodegenContext): String = {
    child.asInstanceOf[CodegenSupport].produce(ctx, this)
  }

  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    consume(ctx, input)
  }
}
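For context, this instrumentation is normally reached through the debug package's implicit helpers rather than by constructing DebugExec directly. A minimal sketch, assuming a running `spark` session:

import org.apache.spark.sql.execution.debug._

// Wraps every operator of the physical plan in DebugExec, runs the query, and
// prints per-operator tuple counts plus the runtime classes seen per column.
val df = spark.range(10).selectExpr("id", "cast(id as string) as s")
df.debug()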
Example 31
Source File: subquery.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using the schema of subqueries to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
}
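A small self-contained sketch of the dedup pattern the rule relies on: group candidates under a cheap key so the expensive sameResult comparison only runs within a group. The String stand-ins here are illustrative, not the real plan types:

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

// `schema` plays the role of sub.plan.schema and `plan` the SubqueryExec
// being deduplicated.
val candidates = mutable.HashMap[String, ArrayBuffer[String]]()
def reuseOrRegister(schema: String, plan: String): String = {
  val group = candidates.getOrElseUpdate(schema, ArrayBuffer[String]())
  group.find(_ == plan) match {
    case Some(existing) => existing       // reuse the previously seen plan
    case None => group += plan; plan      // first occurrence: register it
  }
}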
Example 32
Source File: TimestampCast.scala From flint with Apache License 2.0 | 5 votes |
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, CodeGenerator, JavaCode, Block}
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, TimestampType}

case class TimestampToNanos(child: Expression) extends TimestampCast {
  val dataType: DataType = LongType
  protected def cast(childPrim: String): String =
    s"$childPrim * 1000L"
  override protected def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Long] * 1000L
}

case class NanosToTimestamp(child: Expression) extends TimestampCast {
  val dataType: DataType = TimestampType
  protected def cast(childPrim: String): String =
    s"$childPrim / 1000L"
  override protected def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Long] / 1000L
}

// Restored: the members below belong to the TimestampCast trait that both case
// classes extend; the listing mislabeled this block as `object TimestampToNanos`,
// where the references to `child`, `dataType` and `cast` could not compile.
trait TimestampCast extends UnaryExpression with NullIntolerant {

  protected def cast(childPrim: String): String

  private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
    resultPrim: String, resultNull: String, resultType: DataType): Block = {
    code"""
      boolean $resultNull = $childNull;
      ${CodeGenerator.javaType(resultType)} $resultPrim = ${CodeGenerator.defaultValue(resultType)};
      if (!${childNull}) {
        $resultPrim = (long) ${cast(childPrim)};
      }
    """
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType))
  }
}
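A usage sketch: flint-style code can expose these casts by wrapping them in a Column (the column names here are hypothetical). The * 1000L and / 1000L above reflect that Catalyst stores TimestampType values at microsecond precision:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col

val toNanos     = new Column(TimestampToNanos(col("time").expr))    // timestamp -> long nanos
val toTimestamp = new Column(NanosToTimestamp(col("nanos").expr))   // long nanos -> timestamp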
Example 33
Source File: InputFileName.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileNameHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the current file being read if available",
  extended = "> SELECT _FUNC_();\n ''")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = true

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initInternal(): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileNameHolder.getInputFileName()
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
      "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false")
  }
}
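The expression backs the public input_file_name function; a minimal usage sketch, assuming a running `spark` session and a hypothetical input path:

import org.apache.spark.sql.functions.input_file_name

// Each row is tagged with the file it was read from; outside a file scan the
// expression returns the empty string.
val tagged = spark.read.text("/path/to/logs")
  .withColumn("source_file", input_file_name())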
Example 34
Source File: MonotonicallyIncreasingID.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}

// Restored: the members below belong to the MonotonicallyIncreasingID
// expression; its declaration was dropped in the original listing.
case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
    val partitionMaskTerm = "partitionMask"
    ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = FalseLiteral)
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"

  override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID()
}
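A usage sketch via the public function, assuming a running `spark` session. The << 33 above is why the ids are unique and increasing within a partition but not consecutive across partitions: the partition index sits in the upper 31 bits and the per-partition count in the lower 33:

import org.apache.spark.sql.functions.monotonically_increasing_id

val withIds = spark.range(4).repartition(2)
  .withColumn("uid", monotonically_increasing_id())
// Partition 0 emits 0, 1, ...; partition 1 emits (1L << 33) = 8589934592, ...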
Example 35
Source File: randomExpressions.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""",
  examples = """
    Examples:
      > SELECT _FUNC_();
       -0.3254147983080288
      > SELECT _FUNC_(0);
       1.1164209726833079
      > SELECT _FUNC_(null);
       1.1164209726833079
  """,
  note = "The function is non-deterministic in general case.")
// scalastyle:on line.size.limit
case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {

  def this() = this(Literal(Utils.random.nextLong(), LongType))

  override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))

  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = classOf[XORShiftRandom].getName
    val rngTerm = ctx.addMutableState(className, "rng")
    ctx.addPartitionInitializationStatement(
      s"$rngTerm = new $className(${seed}L + partitionIndex);")
    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
      isNull = FalseLiteral)
  }

  override def freshCopy(): Randn = Randn(child)
}

object Randn {
  def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
}
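A usage sketch via the public randn function, assuming a running `spark` session. The generated code above seeds each partition with seed + partitionIndex, so a fixed seed gives a reproducible column for a fixed partitioning:

import org.apache.spark.sql.functions.randn

// Same seed and same partitioning => the same gaussian noise on re-runs.
val sampled = spark.range(100).withColumn("noise", randn(42))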
Example 36
Source File: TimeWindow.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

case class TimeWindow(
    timeColumn: Expression,
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long) extends UnaryExpression
  with ImplicitCastInputTypes
  with Unevaluable
  with NonSQLExpression {

  //////////////////////////
  // SQL Constructors
  //////////////////////////
  def this(
      timeColumn: Expression,
      windowDuration: Expression,
      slideDuration: Expression,
      startTime: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
  }

  def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), 0)
  }

  def this(timeColumn: Expression, windowDuration: Expression) = {
    this(timeColumn, windowDuration, windowDuration)
  }

  override def child: Expression = timeColumn
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = new StructType()
    .add(StructField("start", TimestampType))
    .add(StructField("end", TimestampType))

  // This expression is replaced in the analyzer.
  override lazy val resolved = false
}
// The companion object TimeWindow (which defines parseExpression) is elided in
// this listing.

case class PreciseTimestampConversion(
    child: Expression,
    fromType: DataType,
    toType: DataType) extends UnaryExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
  override def dataType: DataType = toType
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      code"""boolean ${ev.isNull} = ${eval.isNull};
         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
       """.stripMargin)
  }
  override def nullSafeEval(input: Any): Any = input
}
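TimeWindow itself is Unevaluable and is rewritten by the analyzer; end users reach it through the public window function. A minimal sketch, assuming a running `spark` session and a hypothetical events source with a TimestampType column ts:

import org.apache.spark.sql.functions.{col, count, window}

val events = spark.read.parquet("/path/to/events")   // hypothetical source
// Sliding windows: 10-minute windows, evaluated every 5 minutes.
val counts = events
  .groupBy(window(col("ts"), "10 minutes", "5 minutes"))
  .agg(count("*"))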
Example 37
Source File: constraintExpressions.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.DataType

case class KnownNotNull(child: Expression) extends UnaryExpression {
  override def nullable: Boolean = false
  override def dataType: DataType = child.dataType

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    child.genCode(ctx).copy(isNull = FalseLiteral)
  }

  override def eval(input: InternalRow): Any = {
    child.eval(input)
  }
}
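A minimal sketch of the contract, assuming Catalyst internals on the classpath: the wrapper only overrides nullability (letting codegen skip the null check), while evaluation is delegated to the child unchanged:

import org.apache.spark.sql.catalyst.expressions.Literal

val wrapped = KnownNotNull(Literal(1))
assert(!wrapped.nullable)
assert(wrapped.eval(null) == 1)   // delegates straight to the child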
Example 38
Source File: BoundAttribute.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._

case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (nullable && input.isNullAt(ordinal)) {
      null
    } else {
      accessor(input, ordinal)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      ev.copy(code = oev.code)
    } else {
      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
      val javaType = JavaCode.javaType(dataType)
      val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
      if (nullable) {
        ev.copy(code =
          code"""
             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
             |$javaType ${ev.value} = ${ev.isNull} ?
             |  ${CodeGenerator.defaultValue(dataType)} : ($value);
           """.stripMargin)
      } else {
        ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
      }
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform {
      case a: AttributeReference =>
        attachTree(a, "Binding attribute") {
          val ordinal = input.indexOf(a.exprId)
          if (ordinal == -1) {
            if (allowFailures) {
              a
            } else {
              sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
            }
          } else {
            BoundReference(ordinal, a.dataType, input(ordinal).nullable)
          }
        }
    }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
  }
}
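A short sketch of binding, assuming an input schema of two attributes: the named reference is replaced with an ordinal so eval can index straight into the row:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// "a" is the second attribute of the input, so it binds to ordinal 1.
val bound = BindReferences.bindReference(a, Seq(b, a))
assert(bound == BoundReference(1, IntegerType, nullable = true))
assert(bound.eval(InternalRow(7, 42)) == 42)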
Example 39
Source File: decimalExpressions.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.types._

case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
}
Example 40
Source File: ExpressionEvalHelperSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}

case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      code"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
}
Example 41
Source File: subquery.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
}
Example 42
Source File: ColumnarSubquery.scala From OAP with Apache License 2.0 | 5 votes |
package com.intel.sparkColumnarPlugin.expression

import org.apache.arrow.gandiva.evaluator._
import org.apache.arrow.gandiva.exceptions.GandivaException
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.BaseSubqueryExec
import org.apache.spark.sql.execution.ExecSubqueryExpression
import org.apache.spark.sql.execution.ScalarSubquery
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

class ColumnarScalarSubquery(
  query: ScalarSubquery)
  extends Expression with ColumnarExpression {

  override def dataType: DataType = query.dataType
  override def children: Seq[Expression] = Nil
  override def nullable: Boolean = true
  override def toString: String = query.toString
  override def eval(input: InternalRow): Any = query.eval(input)
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    query.doGenCode(ctx, ev)
  override def canEqual(that: Any): Boolean = query.canEqual(that)
  override def productArity: Int = query.productArity
  override def productElement(n: Int): Any = query.productElement(n)

  override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
    val value = query.eval(null)
    val resultType = CodeGeneration.getResultType(query.dataType)
    query.dataType match {
      case t: StringType =>
        (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType)
      case t: IntegerType =>
        (TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType)
      case t: LongType =>
        (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Long]), resultType)
      case t: DoubleType =>
        (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType)
      case d: DecimalType =>
        val v = value.asInstanceOf[Decimal]
        (TreeBuilder.makeDecimalLiteral(v.toString, v.precision, v.scale), resultType)
      case d: DateType =>
        throw new UnsupportedOperationException(s"DateType is not supported yet.")
    }
  }
}
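A sketch of where the wrapper would slot in during columnar planning; the surrounding transformer is hypothetical, only ColumnarScalarSubquery itself comes from the file above:

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.ScalarSubquery

// Hypothetical rewrite step: swap each ScalarSubquery for its columnar wrapper
// so doColumnarCodeGen can embed the (already evaluated) result as a Gandiva
// literal node.
def toColumnar(expr: Expression): Expression = expr match {
  case s: ScalarSubquery => new ColumnarScalarSubquery(s)
  case other => other
}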