org.apache.spark.sql.catalyst.expressions.Cast Scala Examples
The following examples show how to use org.apache.spark.sql.catalyst.expressions.Cast.
Each example is taken from an open-source project; the source file, project, and license are noted above each listing.
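All of the examples below share one core pattern: wrap a child expression in Cast with a target DataType, then either evaluate it eagerly with eval() or compile it into a projection. A minimal standalone sketch of that pattern (illustrative, not taken from any of the projects below):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.IntegerType

    // Cast is an unevaluated Catalyst expression; eval() interprets it directly.
    val casted = Cast(Literal("42"), IntegerType).eval() // 42 (boxed Int)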
Example 1
Source File: MessageDelimiter.scala From spark-cep with Apache License 2.0
package org.apache.spark.sql.streaming.sources

import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, EmptyRow, Literal}
import org.apache.spark.sql.types.StructType

class MessageDelimiter extends MessageToRowConverter with Logging {
  val delimiter = " "

  def toRow(msg: String, schema: StructType): InternalRow = {
    val splitted = msg.split(delimiter).map(Literal(_))
    val casted = splitted.indices.map(i => Cast(splitted(i), schema(i).dataType).eval(EmptyRow))
    InternalRow.fromSeq(casted)
  }

  def toMessage(row: Row): String = row.mkString(delimiter)
}

trait MessageToRowConverter extends Serializable {
  def toRow(message: String, schema: StructType): InternalRow

  def toMessage(row: Row): String
}
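One possible way to exercise the converter (the schema and message here are hypothetical, not part of the project's sources): each space-separated token becomes a Literal and is cast to the matching field type via Cast(...).eval(EmptyRow).

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    val row = new MessageDelimiter().toRow("7 alice", schema)
    row.getInt(0) // 7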
Example 2
Source File: HiveTypeCoercionSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  baseTypes.foreach { i =>
    baseTypes.foreach { j =>
      createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { i =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $i else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $i end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.executedPlan.collect {
      case e: Project => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }
}
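The assertion style in this suite, using transformAllExpressions as a visitor, generalizes to any expression class. A small sketch of the same pattern outside the suite; plan here is a stand-in for any LogicalPlan or SparkPlan you have already built:

    import org.apache.spark.sql.catalyst.expressions.Cast

    // Count Cast nodes anywhere in the plan's expressions. The rule must
    // return an expression, so each case hands back its input unchanged.
    var numCasts = 0
    plan.transformAllExpressions {
      case c: Cast =>
        numCasts += 1
        c
    }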
Example 3
Source File: KinesisWriteTask.scala From kinesis-sql with Apache License 2.0
package org.apache.spark.sql.kinesis

import java.nio.ByteBuffer

import com.amazonaws.services.kinesis.producer.{KinesisProducer, UserRecordResult}
import com.google.common.util.concurrent.{FutureCallback, Futures}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection}
import org.apache.spark.sql.types.{BinaryType, StringType}

private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, String],
                                        inputSchema: Seq[Attribute]) extends Logging {

  private var producer: KinesisProducer = _
  private val projection = createProjection
  private val streamName = producerConfiguration.getOrElse(
    KinesisSourceProvider.SINK_STREAM_NAME_KEY, "")

  def execute(iterator: Iterator[InternalRow]): Unit = {
    producer = CachedKinesisProducer.getOrCreate(producerConfiguration)
    while (iterator.hasNext) {
      val currentRow = iterator.next()
      val projectedRow = projection(currentRow)
      val partitionKey = projectedRow.getString(0)
      val data = projectedRow.getBinary(1)

      sendData(partitionKey, data)
    }
  }

  def sendData(partitionKey: String, data: Array[Byte]): String = {
    var sentSeqNumbers = new String

    val future = producer.addUserRecord(streamName, partitionKey, ByteBuffer.wrap(data))

    val kinesisCallBack = new FutureCallback[UserRecordResult]() {

      override def onFailure(t: Throwable): Unit = {
        logError(s"Writing to $streamName failed due to ${t.getCause}")
      }

      override def onSuccess(result: UserRecordResult): Unit = {
        val shardId = result.getShardId
        sentSeqNumbers = result.getSequenceNumber
      }
    }
    Futures.addCallback(future, kinesisCallBack)

    producer.flushSync()
    sentSeqNumbers
  }

  def close(): Unit = {
    if (producer != null) {
      producer.flush()
      producer = null
    }
  }

  private def createProjection: UnsafeProjection = {
    val partitionKeyExpression = inputSchema
      .find(_.name == KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME).getOrElse(
        throw new IllegalStateException("Required attribute " +
          s"'${KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME}' not found"))

    partitionKeyExpression.dataType match {
      case StringType | BinaryType => // ok
      case t =>
        throw new IllegalStateException(s"${KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME} " +
          "attribute type must be a String or BinaryType")
    }

    val dataExpression = inputSchema.find(_.name == KinesisWriter.DATA_ATTRIBUTE_NAME).getOrElse(
      throw new IllegalStateException("Required attribute " +
        s"'${KinesisWriter.DATA_ATTRIBUTE_NAME}' not found")
    )

    dataExpression.dataType match {
      case StringType | BinaryType => // ok
      case t =>
        throw new IllegalStateException(s"${KinesisWriter.DATA_ATTRIBUTE_NAME} " +
          "attribute type must be a String or BinaryType")
    }

    UnsafeProjection.create(
      Seq(Cast(partitionKeyExpression, StringType), Cast(dataExpression, StringType)),
      inputSchema)
  }
}
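createProjection compiles its two Casts into code-generated field accessors. A minimal standalone sketch of that UnsafeProjection-plus-Cast combination (the attribute name is illustrative):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, UnsafeProjection}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // Project an int column to its string form via a compiled Cast.
    val attr = AttributeReference("partitionKey", IntegerType)()
    val projection = UnsafeProjection.create(Seq(Cast(attr, StringType)), Seq(attr))
    projection(InternalRow(42)).getUTF8String(0) // "42"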
Example 4
Source File: ResolveInlineTablesSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}

class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables(conf).validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables(conf).validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(conf)(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables(conf).convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("convert TimeZoneAwareExpression") {
    val table = UnresolvedInlineTable(Seq("c1"),
      Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
    val withTimeZone = ResolveTimeZone(conf).apply(table)
    val LocalRelation(output, data, _) = ResolveInlineTables(conf).apply(withTimeZone)
    val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
    assert(output.map(_.dataType) == Seq(TimestampType))
    assert(data.size == 1)
    assert(data.head.getLong(0) == correct)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables(conf).convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables(conf).convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
}
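ResolveInlineTables is the analyzer rule behind SQL VALUES lists, so the convert behavior tested above is observable from plain SQL. A sketch, assuming a running SparkSession named spark:

    // The int 1 and the bigint 2L widen to LongType; the int literal is cast
    // eagerly during analysis, exactly as the "convert" test asserts.
    spark.sql("SELECT * FROM VALUES (1), (2L) AS t(c1)").schema
    // StructType(StructField(c1, LongType, nullable = false))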
Example 5
Source File: HiveTypeCoercionSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq(
    ("1", "1"), ("1.0", "CAST(1.0 AS DOUBLE)"), ("1L", "1L"),
    ("1S", "1S"), ("1Y", "1Y"), ("'1'", "'1'"))

  baseTypes.foreach { case (ni, si) =>
    baseTypes.foreach { case (nj, sj) =>
      createQueryTest(s"$ni + $nj", s"SELECT $si + $sj FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { case (i, s) =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $s else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $s end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.sparkPlan.collect {
      case e: ProjectExec => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }
}
Example 6
Source File: HiveTypeCoercionSuite.scala From spark1.52 with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  baseTypes.foreach { i =>
    baseTypes.foreach { j =>
      createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { i =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $i else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $i end FROM src limit 1")
  }

  // A boolean cast applied to a boolean value should be removed
  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.executedPlan.collect {
      case e: Project => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }
}
Example 7
Source File: DateUtils.scala From iolap with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import java.sql.Date
import java.text.SimpleDateFormat
import java.util.{Calendar, TimeZone}

import org.apache.spark.sql.catalyst.expressions.Cast

object DateUtils {
  private val MILLIS_PER_DAY = 86400000

  // Java TimeZone has no mention of thread safety. Use thread local instance to be safe.
  private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] {
    override protected def initialValue: TimeZone = {
      Calendar.getInstance.getTimeZone
    }
  }

  private def javaDateToDays(d: Date): Int = {
    millisToDays(d.getTime)
  }

  // we should use the exact day as Int, for example, (year, month, day) -> day
  def millisToDays(millisLocal: Long): Int = {
    ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt
  }

  private def toMillisSinceEpoch(days: Int): Long = {
    val millisUtc = days.toLong * MILLIS_PER_DAY
    millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc)
  }

  def fromJavaDate(date: java.sql.Date): Int = {
    javaDateToDays(date)
  }

  def toJavaDate(daysSinceEpoch: Int): java.sql.Date = {
    new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch))
  }

  def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days))

  def stringToTime(s: String): java.util.Date = {
    if (!s.contains('T')) {
      // JDBC escape string
      if (s.contains(' ')) {
        java.sql.Timestamp.valueOf(s)
      } else {
        java.sql.Date.valueOf(s)
      }
    } else if (s.endsWith("Z")) {
      // this is zero timezone of ISO8601
      stringToTime(s.substring(0, s.length - 1) + "GMT-00:00")
    } else if (s.indexOf("GMT") == -1) {
      // timezone with ISO8601
      val inset = "+00.00".length
      val s0 = s.substring(0, s.length - inset)
      val s1 = s.substring(s.length - inset, s.length)
      if (s0.substring(s0.lastIndexOf(':')).contains('.')) {
        stringToTime(s0 + "GMT" + s1)
      } else {
        stringToTime(s0 + ".0GMT" + s1)
      }
    } else {
      // ISO8601 with GMT insert
      val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz")
      ISO8601GMT.parse(s)
    }
  }
}
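The day arithmetic is easiest to see with a worked value. A small sketch, assuming the JVM default time zone is UTC (the commented results shift with the local offset otherwise):

    // One day after the epoch: 86,400,000 ms / MILLIS_PER_DAY = day 1.
    val days = DateUtils.millisToDays(86400000L) // 1 when the local zone is UTC
    DateUtils.toJavaDate(days)                   // 1970-01-02
    DateUtils.toString(days)                     // "1970-01-02", via Cast's date format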
Example 8
Source File: HiveTypeCoercionSuite.scala From iolap with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  baseTypes.foreach { i =>
    baseTypes.foreach { j =>
      createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { i =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $i else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $i end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.executedPlan.collect {
      case e: Project => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }

  test("COALESCE with different types") {
    intercept[RuntimeException] {
      TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect()
    }
  }
}
Example 9
Source File: ResolveInlineTables.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}

// Enclosing declaration restored for context; the rule's other members
// (apply and the validation helpers) are omitted in this excerpt.
object ResolveInlineTables extends Rule[LogicalPlan] {

  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // For each column, traverse all the values and find a common data type and nullability.
    val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
      val inputTypes = column.map(_.dataType)
      val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      }
      StructField(name, tpe, nullable = column.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
        val targetType = fields(ci).dataType
        try {
          if (e.dataType.sameType(targetType)) {
            e.eval()
          } else {
            Cast(e, targetType).eval()
          }
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
        }
      })
    }

    LocalRelation(attributes, newRows)
  }
}
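The key move in convert is the eager Cast(...).eval() on each value whose type differs from the widened column type. A one-line sketch of what that produces for the common int-to-long widening:

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.LongType

    // An IntegerType literal widened to the column's LongType at analysis time.
    Cast(Literal(1), LongType).eval() // 1L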
Example 10
Source File: HiveTypeCoercionSuite.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq(
    ("1", "1"), ("1.0", "CAST(1.0 AS DOUBLE)"), ("1L", "1L"),
    ("1S", "1S"), ("1Y", "1Y"), ("'1'", "'1'"))

  baseTypes.foreach { case (ni, si) =>
    baseTypes.foreach { case (nj, sj) =>
      createQueryTest(s"$ni + $nj", s"SELECT $si + $sj FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { case (i, s) =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $s else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $s end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.sparkPlan.collect {
      case e: ProjectExec => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }
}
Example 11
Source File: HiveTypeCoercionSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq(
    ("1", "1"), ("1.0", "CAST(1.0 AS DOUBLE)"), ("1L", "1L"),
    ("1S", "1S"), ("1Y", "1Y"), ("'1'", "'1'"))

  baseTypes.foreach { case (ni, si) =>
    baseTypes.foreach { case (nj, sj) =>
      createQueryTest(s"$ni + $nj", s"SELECT $si + $sj FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { case (i, s) =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $s else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $s end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.sparkPlan.collect {
      case e: ProjectExec => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }
}
Example 12
Source File: CarbonExpressions.scala From carbondata with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.types.DataType

// Enclosing object restored for context; its other extractors (the ones that
// use Cast, TableIdentifier, and the remaining imports) are omitted in this excerpt.
object CarbonExpressions {

  object CarbonScalaUDF {
    def unapply(expression: Expression): Option[(ScalaUDF)] = {
      expression match {
        case a: ScalaUDF =>
          Some(a)
        case _ =>
          None
      }
    }
  }
}
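Extractors such as CarbonScalaUDF exist so that pattern matches over Catalyst trees stay readable. A hypothetical use site, assuming the enclosing CarbonExpressions object as restored above:

    import org.apache.spark.sql.CarbonExpressions.CarbonScalaUDF
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Hypothetical rule fragment: pick ScalaUDF nodes out of arbitrary expressions.
    def describe(expr: Expression): String = expr match {
      case CarbonScalaUDF(udf) => s"ScalaUDF returning ${udf.dataType}"
      case other => other.getClass.getSimpleName
    }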
Example 13
Source File: ResolveInlineTables.scala From sparkoscope with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}

// Enclosing declaration restored for context; the rule's other members
// (apply and the validation helpers) are omitted in this excerpt.
object ResolveInlineTables extends Rule[LogicalPlan] {

  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // For each column, traverse all the values and find a common data type and nullability.
    val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
      val inputTypes = column.map(_.dataType)
      val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      }
      StructField(name, tpe, nullable = column.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
        val targetType = fields(ci).dataType
        try {
          if (e.dataType.sameType(targetType)) {
            e.eval()
          } else {
            Cast(e, targetType).eval()
          }
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
        }
      })
    }

    LocalRelation(attributes, newRows)
  }
}
Example 14
Source File: HiveTypeCoercionSuite.scala From sparkoscope with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq(
    ("1", "1"), ("1.0", "CAST(1.0 AS DOUBLE)"), ("1L", "1L"),
    ("1S", "1S"), ("1Y", "1Y"), ("'1'", "'1'"))

  baseTypes.foreach { case (ni, si) =>
    baseTypes.foreach { case (nj, sj) =>
      createQueryTest(s"$ni + $nj", s"SELECT $si + $sj FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { case (i, s) =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $s else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $s end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.sparkPlan.collect {
      case e: ProjectExec => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }
}
Example 15
Source File: GlobalSapSQLContext.scala From HANAVora-Extensions with Apache License 2.0
package org.apache.spark.sql

import java.io.File

import com.sap.spark.util.TestUtils
import com.sap.spark.{GlobalSparkContext, WithSQLContext}
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
import org.apache.spark.unsafe.types._
import org.apache.spark.sql.types._
import org.scalatest.Suite

import scala.io.Source

trait GlobalSapSQLContext extends GlobalSparkContext with WithSQLContext {
  self: Suite =>

  override implicit def sqlContext: SQLContext = GlobalSapSQLContext._sqlc

  override protected def setUpSQLContext(): Unit =
    GlobalSapSQLContext.init(sc)

  override protected def tearDownSQLContext(): Unit =
    GlobalSapSQLContext.reset()

  def getDataFrameFromSourceFile(sparkSchema: StructType, path: File): DataFrame = {
    val conversions = sparkSchema.toSeq.zipWithIndex.map({
      case (field, index) =>
        Cast(BoundReference(index, StringType, nullable = true), field.dataType)
    })
    val data = Source.fromFile(path)
      .getLines()
      .map({ line =>
        val stringRow = InternalRow.fromSeq(line.split(",", -1).map(UTF8String.fromString))
        Row.fromSeq(conversions.map({ c => c.eval(stringRow) }))
      })
    val rdd = sc.parallelize(data.toSeq, numberOfSparkWorkers)
    sqlContext.createDataFrame(rdd, sparkSchema)
  }
}

object GlobalSapSQLContext {
  private var _sqlc: SQLContext = _

  private def init(sc: SparkContext): Unit =
    if (_sqlc == null) {
      _sqlc = TestUtils.newSQLContext(sc)
    }

  private def reset(): Unit = {
    if (_sqlc != null) {
      _sqlc.catalog.unregisterAllTables()
    }
  }
}
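getDataFrameFromSourceFile pairs each BoundReference with a Cast so every CSV column is parsed from its string form. The same pairing in isolation, as a self-contained sketch:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
    import org.apache.spark.sql.types.{IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // Column 0 is a string in the input row; Cast parses it into an int.
    val parse = Cast(BoundReference(0, StringType, nullable = true), IntegerType)
    parse.eval(InternalRow(UTF8String.fromString("7"))) // 7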
Example 16
Source File: ResolveInlineTablesSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}

class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables(conf).validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables(conf).validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(conf)(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables(conf).convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("convert TimeZoneAwareExpression") {
    val table = UnresolvedInlineTable(Seq("c1"),
      Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
    val withTimeZone = ResolveTimeZone(conf).apply(table)
    val LocalRelation(output, data, _) = ResolveInlineTables(conf).apply(withTimeZone)
    val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
    assert(output.map(_.dataType) == Seq(TimestampType))
    assert(data.size == 1)
    assert(data.head.getLong(0) == correct)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables(conf).convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables(conf).convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
}
Example 17
Source File: view.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

// Excerpt: the file's other view rules, which use Alias, Cast, and Project,
// are omitted in this listing.
object EliminateView extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // The child should have the same output attributes with the View operator, so we simply
    // remove the View operator.
    case View(_, output, child) =>
      assert(output == child.output,
        s"The output of the child ${child.output.mkString("[", ",", "]")} is different from the " +
          s"view output ${output.mkString("[", ",", "]")}")
      child
  }
}
Example 18
Source File: HiveTypeCoercionSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive

class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq(
    ("1", "1"), ("1.0", "CAST(1.0 AS DOUBLE)"), ("1L", "1L"),
    ("1S", "1S"), ("1Y", "1Y"), ("'1'", "'1'"))

  baseTypes.foreach { case (ni, si) =>
    baseTypes.foreach { case (nj, sj) =>
      createQueryTest(s"$ni + $nj", s"SELECT $si + $sj FROM src LIMIT 1")
    }
  }

  val nullVal = "null"
  baseTypes.init.foreach { case (i, s) =>
    createQueryTest(s"case when then $i else $nullVal end ",
      s"SELECT case when true then $s else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $i end ",
      s"SELECT case when true then $nullVal else $s end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val q = "select cast(cast(key=0 as boolean) as boolean) from src"
    val project = TestHive.sql(q).queryExecution.sparkPlan.collect {
      case e: ProjectExec => e
    }.head

    // No cast expression introduced
    project.transformAllExpressions {
      case c: Cast =>
        fail(s"unexpected cast $c")
        c
    }

    // Only one equality check
    var numEquals = 0
    project.transformAllExpressions {
      case e: EqualTo =>
        numEquals += 1
        e
    }

    assert(numEquals === 1)
  }
}
Example 19
Source File: ResolveInlineTables.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}

// Enclosing declaration restored for context; the rule's other members
// (apply and the validation helpers) are omitted in this excerpt.
object ResolveInlineTables extends Rule[LogicalPlan] {

  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // For each column, traverse all the values and find a common data type and nullability.
    val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
      val inputTypes = column.map(_.dataType)
      val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      }
      StructField(name, tpe, nullable = column.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
        val targetType = fields(ci).dataType
        try {
          if (e.dataType.sameType(targetType)) {
            e.eval()
          } else {
            Cast(e, targetType).eval()
          }
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
        }
      })
    }

    LocalRelation(attributes, newRows)
  }
}