org.apache.spark.sql.catalyst.expressions.Alias Scala Examples
The following examples show how to use org.apache.spark.sql.catalyst.expressions.Alias.
Each example notes its source file and the project it comes from, so you can trace it back to the original context.
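All of the examples below build aliases with the same basic pattern, so here is a minimal standalone sketch of it first (the expression, alias name, and plan are made up for illustration; the empty second argument list and the toAttribute call mirror what the examples do):

import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

object AliasSketch {
  def main(args: Array[String]): Unit = {
    // Alias wraps any expression and gives it a name. The second (curried) parameter list
    // carries the exprId, qualifier and explicit metadata; leaving it empty generates a fresh ExprId.
    val aliased = Alias(Literal(42), "answer")()

    // An Alias is a NamedExpression, so it can sit directly in a Project list ...
    val plan = Project(Seq(aliased), LocalRelation())

    // ... and toAttribute yields the AttributeReference that parent operators use to refer to it.
    val attr = aliased.toAttribute
    println(s"$plan outputs $attr")
  }
}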
Example 1
Source File: ComputeCurrentTimeSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
}
Example 2
Source File: ExpandSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType

class ExpandSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  private def testExpand(f: SparkPlan => SparkPlan): Unit = {
    val input = (1 to 1000).map(Tuple1.apply)
    val projections = Seq.tabulate(2) { i =>
      Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
    }
    val attributes = projections.head.map(_.toAttribute)
    checkAnswer(
      input.toDF(),
      plan => Expand(projections, attributes, f(plan)),
      input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))
    )
  }

  test("inheriting child row type") {
    val exprs = AttributeReference("a", IntegerType, false)() :: Nil
    val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
    assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.")
  }

  test("expanding UnsafeRows") {
    testExpand(ConvertToUnsafe)
  }

  test("expanding SafeRows") {
    testExpand(identity)
  }
}
Example 3
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
  // histogram usually contains hundreds of buckets. So we need to test
  // ApproxCountDistinctForIntervals with large number of endpoints
  // (the number of endpoints == the number of buckets + 1).
  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
    val table = "approx_count_distinct_for_intervals_tbl"
    withTable(table) {
      (1 to 100000).toDF("col").createOrReplaceTempView(table)
      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
      val endpoints = (0 to 1000).map(_ * 100000 / 1000)

      // Since approx_count_distinct_for_intervals is not a public function, here we do
      // the computation by constructing logical plan.
      val relation = spark.table(table).logicalPlan
      val attr = relation.output.find(_.name == "col").get
      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
      val aggExpr = aggFunc.toAggregateExpression()
      val namedExpr = Alias(aggExpr, aggExpr.toString)()
      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
        .executedPlan.executeTake(1).head
      val ndvArray = ndvsRow.getArray(0).toLongArray()
      assert(endpoints.length == ndvArray.length + 1)

      // Each bucket has 100 distinct values.
      val expectedNdv = 100
      for (i <- ndvArray.indices) {
        val ndv = ndvArray(i)
        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
      }
    }
  }
}
Example 4
Source File: ComputeCurrentTimeSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
}
Example 5
Source File: OptimizerStructuralIntegrityCheckerSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {

  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Project(projectList, child) =>
        val newAttr = UnresolvedAttribute("unresolvedAttr")
        Project(projectList ++ Seq(newAttr), child)
    }
  }

  object Optimize extends Optimizer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SQLConf())) {
    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
    override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches
  }

  test("check for invalid plan after execution of rule") {
    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
    assert(analyzed.resolved)
    val message = intercept[TreeNodeException[LogicalPlan]] {
      Optimize.execute(analyzed)
    }.getMessage
    val ruleName = OptimizeRuleBreakSI.ruleName
    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
    assert(message.contains("the structural integrity of the plan is broken"))
  }
}
Example 6
Source File: ProjectEstimation.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}

object ProjectEstimation {
  import EstimationUtils._

  def estimate(project: Project): Option[Statistics] = {
    if (rowCountsExist(project.child)) {
      val childStats = project.child.stats
      val inputAttrStats = childStats.attributeStats
      // Match alias with its child's column stat
      val aliasStats = project.expressions.collect {
        case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) =>
          alias.toAttribute -> inputAttrStats(attr)
      }
      val outputAttrStats =
        getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
      Some(childStats.copy(
        sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
        attributeStats = outputAttrStats))
    } else {
      None
    }
  }
}
Example 7
Source File: FramelessInternals.scala From frameless with Apache License 2.0 | 5 votes |
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ObjectType

import scala.reflect.ClassTag

object FramelessInternals {
  def objectTypeFor[A](implicit classTag: ClassTag[A]): ObjectType =
    ObjectType(classTag.runtimeClass)

  def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = {
    ds.toDF.queryExecution.analyzed.resolve(colNames,
      ds.sparkSession.sessionState.analyzer.resolver).getOrElse {
      throw new AnalysisException(
        s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames.mkString(", ")})""")
    }
  }

  def expr(column: Column): Expression = column.expr

  def column(column: Column): Expression = column.expr

  def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan

  def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution =
    ds.sparkSession.sessionState.executePlan(plan)

  def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan,
      rightPlan: LogicalPlan): LogicalPlan = {
    val joined = executePlan(ds, plan)
    val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
    val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)

    Project(List(
      Alias(CreateStruct(leftOutput), "_1")(),
      Alias(CreateStruct(rightOutput), "_2")()
    ), joined.analyzed)
  }

  def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] =
    new Dataset(sqlContext, plan, encoder)

  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
    Dataset.ofRows(sparkSession, logicalPlan)

  // because org.apache.spark.sql.types.UserDefinedType is private[spark]
  type UserDefinedType[A >: Null] = org.apache.spark.sql.types.UserDefinedType[A]

  case class DisambiguateRight[T](tagged: Expression) extends Expression with NonSQLExpression {
    def eval(input: InternalRow): Any = tagged.eval(input)
    def nullable: Boolean = false
    def children: Seq[Expression] = tagged :: Nil
    def dataType: DataType = tagged.dataType
    protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ???
    override def genCode(ctx: CodegenContext): ExprCode = tagged.genCode(ctx)
  }
}
Example 8
Source File: ExchangeSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits._

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ShuffleExchange(SinglePartition, plan),
      input.map(Row.fromTuple)
    )
  }

  test("compatible BroadcastMode") {
    val mode1 = IdentityBroadcastMode
    val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
    val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

    assert(mode1.compatibleWith(mode1))
    assert(!mode1.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode1))
    assert(mode2.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode3))
    assert(mode3.compatibleWith(mode3))
  }

  test("BroadcastExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan)
    val hashMode = HashedRelationBroadcastMode(output)
    val exchange2 = BroadcastExchangeExec(hashMode, plan)
    val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
    val exchange3 = BroadcastExchangeExec(hashMode2, plan)
    val exchange4 = ReusedExchangeExec(output, exchange3, sparkContext.sparkUser)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)

    assert(!exchange1.sameResult(exchange2))
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(exchange4 sameResult exchange3)
  }

  test("ShuffleExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val part1 = HashPartitioning(output, 1)
    val exchange1 = ShuffleExchange(part1, plan)
    val exchange2 = ShuffleExchange(part1, plan)
    val part2 = HashPartitioning(output, 2)
    val exchange3 = ShuffleExchange(part2, plan)
    val part3 = HashPartitioning(output ++ output, 2)
    val exchange4 = ShuffleExchange(part3, plan)
    val exchange5 = ReusedExchangeExec(output, exchange4, sparkContext.sparkUser)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)
    assert(exchange5 sameResult exchange5)

    assert(exchange1 sameResult exchange2)
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(!exchange4.sameResult(exchange5))
    assert(exchange5 sameResult exchange4)
  }
}
Example 9
Source File: ComputeCurrentTimeSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
}
Example 10
Source File: EncodeLongTest.scala From morpheus with Apache License 2.0 | 5 votes |
package org.opencypher.morpheus.impl.encoders

import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.expressions.{Alias, GenericInternalRow}
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.typedLit
import org.opencypher.morpheus.api.value.MorpheusElement._
import org.opencypher.morpheus.impl.expressions.EncodeLong
import org.opencypher.morpheus.impl.expressions.EncodeLong._
import org.opencypher.morpheus.testing.MorpheusTestSuite
import org.scalatestplus.scalacheck.Checkers

class EncodeLongTest extends MorpheusTestSuite with Checkers {

  it("encodes longs correctly") {
    check((l: Long) => {
      val scala = l.encodeAsMorpheusId.toList
      val spark = typedLit[Long](l).encodeLongAsMorpheusId.expr.eval().asInstanceOf[Array[Byte]].toList
      scala === spark
    }, minSuccessful(1000))
  }

  it("encoding/decoding is symmetric") {
    check((l: Long) => {
      val encoded = l.encodeAsMorpheusId
      val decoded = decodeLong(encoded)
      decoded === l
    }, minSuccessful(1000))
  }

  it("scala version encodes longs correctly") {
    0L.encodeAsMorpheusId.toList should equal(List(0.toByte))
  }

  it("spark version encodes longs correctly") {
    typedLit[Long](0L).encodeLongAsMorpheusId.expr.eval().asInstanceOf[Array[Byte]].array.toList should equal(List(0.toByte))
  }

  describe("Spark expression") {

    it("converts longs into byte arrays using expression interpreter") {
      check((l: Long) => {
        val positive = l & Long.MaxValue
        val inputRow = new GenericInternalRow(Array[Any](positive))
        val encodeLong = EncodeLong(functions.lit(positive).expr)
        val interpreted = encodeLong.eval(inputRow).asInstanceOf[Array[Byte]]
        val decoded = decodeLong(interpreted)
        decoded === positive
      }, minSuccessful(1000))
    }

    it("converts longs into byte arrays using expression code gen") {
      check((l: Long) => {
        val positive = l & Long.MaxValue
        val inputRow = new GenericInternalRow(Array[Any](positive))
        val encodeLong = EncodeLong(functions.lit(positive).expr)
        val plan = GenerateMutableProjection.generate(Alias(encodeLong, s"Optimized($encodeLong)")() :: Nil)
        val codegen = plan(inputRow).get(0, encodeLong.dataType).asInstanceOf[Array[Byte]]
        val decoded = decodeLong(codegen)
        decoded === positive
      }, minSuccessful(1000))
    }
  }
}
Example 11
Source File: PostAggregate.scala From spark-druid-olap with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources.druid

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.sparklinedata.druid._

class PostAggregate(val druidOpSchema : DruidOperatorSchema) {

  val dqb = druidOpSchema.dqb

  private def attrRef(dOpAttr : DruidOperatorAttribute) : AttributeReference =
    AttributeReference(dOpAttr.name, dOpAttr.dataType)(dOpAttr.exprId)

  lazy val groupExpressions = dqb.dimensions.map { d =>
    attrRef(druidOpSchema.druidAttrMap(d.outputName))
  }

  def namedGroupingExpressions = groupExpressions

  private def toSparkAgg(dAggSpec : AggregationSpec) : Option[AggregateFunction] = {
    val dOpAttr = druidOpSchema.druidAttrMap(dAggSpec.name)
    dAggSpec match {
      case FunctionAggregationSpec("count", nm, _) => Some(Sum(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longSum", nm, _) => Some(Sum(attrRef(dOpAttr)))
      case FunctionAggregationSpec("doubleSum", nm, _) => Some(Sum(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longMin", nm, _) => Some(Min(attrRef(dOpAttr)))
      case FunctionAggregationSpec("doubleMin", nm, _) => Some(Min(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longMax", nm, _) => Some(Max(attrRef(dOpAttr)))
      case FunctionAggregationSpec("doubleMax", nm, _) => Some(Max(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("MIN") =>
        Some(Min(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("MAX") =>
        Some(Max(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("SUM") =>
        Some(Sum(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("COUNT") =>
        Some(Sum(attrRef(dOpAttr)))
      case _ => None
    }
  }

  lazy val aggregatesO : Option[List[NamedExpression]] = Utils.sequence(
    dqb.aggregations.map { da =>
      val dOpAttr = druidOpSchema.druidAttrMap(da.name)
      toSparkAgg(da).map { aggFunc =>
        Alias(AggregateExpression(aggFunc, Complete, false), dOpAttr.name)(dOpAttr.exprId)
      }
    })

  def canBeExecutedInHistorical : Boolean = dqb.canPushToHistorical && aggregatesO.isDefined

  lazy val resultExpressions = groupExpressions ++ aggregatesO.get

  lazy val aggregateExpressions = resultExpressions.flatMap { expr =>
    expr.collect {
      case agg: AggregateExpression => agg
    }
  }.distinct

  lazy val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
    val aggregateFunction = agg.aggregateFunction
    val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
    (aggregateFunction, agg.isDistinct) -> attribute
  }.toMap

  lazy val rewrittenResultExpressions = resultExpressions.map { expr =>
    expr.transformDown {
      case aE@AggregateExpression(aggregateFunction, _, isDistinct, _) =>
        // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
        // so replace each aggregate expression by its corresponding attribute in the set:
        // aggregateFunctionToAttribute(aggregateFunction, isDistinct)
        aE.resultAttribute
      case expression => expression
    }.asInstanceOf[NamedExpression]
  }

  def aggOp(child : SparkPlan) : Seq[SparkPlan] = {
    org.apache.spark.sql.execution.aggregate.AggUtils.planAggregateWithoutPartial(
      namedGroupingExpressions,
      aggregateExpressions,
      rewrittenResultExpressions,
      child)
  }
}
Example 12
Source File: RangerSparkMaskingExtensionTest.scala From spark-ranger with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.RangerSparkTestUtils._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{Project, RangerSparkMasking}
import org.scalatest.FunSuite

class RangerSparkMaskingExtensionTest extends FunSuite {

  private val spark = TestHive.sparkSession

  test("data masking for bob show last 4") {
    val extension = RangerSparkMaskingExtension(spark)
    val plan = spark.sql("select * from src").queryExecution.optimizedPlan
    println(plan)
    withUser("bob") {
      val newPlan = extension.apply(plan)
      assert(newPlan.isInstanceOf[Project])
      val project = newPlan.asInstanceOf[Project]
      val key = project.projectList.head
      assert(key.name === "key", "no effect on the unmasked attribute")
      val value = project.projectList.tail
      assert(value.head.name === "value", "attribute name should be unchanged")
      assert(value.head.asInstanceOf[Alias].child.sql ===
        "mask_show_last_n(`value`, 4, 'x', 'x', 'x', -1, '1')")
    }

    withUser("alice") {
      val newPlan = extension.apply(plan)
      assert(newPlan === RangerSparkMasking(plan))
    }
  }
}
Example 13
Source File: CarbonUDFTransformRule.scala From carbondata with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, PredicateHelper, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types.StringType

import org.apache.carbondata.core.constants.CarbonCommonConstants

class CarbonUDFTransformRule extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    pushDownUDFToJoinLeftRelation(plan)
  }

  private def pushDownUDFToJoinLeftRelation(plan: LogicalPlan): LogicalPlan = {
    val output = plan.transform {
      case proj@Project(cols, Join(
          left, right, jointype: org.apache.spark.sql.catalyst.plans.JoinType, condition)) =>
        var projectionToBeAdded: Seq[org.apache.spark.sql.catalyst.expressions.Alias] = Seq.empty
        var udfExists = false
        val newCols = cols.map {
          case a@Alias(s: ScalaUDF, name)
            if name.equalsIgnoreCase(CarbonCommonConstants.POSITION_ID) ||
               name.equalsIgnoreCase(CarbonCommonConstants.CARBON_IMPLICIT_COLUMN_TUPLEID) =>
            udfExists = true
            projectionToBeAdded :+= a
            AttributeReference(name, StringType, nullable = true)().withExprId(a.exprId)
          case other => other
        }
        if (udfExists) {
          val newLeft = left match {
            case Project(columns, logicalPlan) =>
              Project(columns ++ projectionToBeAdded, logicalPlan)
            case filter: Filter =>
              Project(filter.output ++ projectionToBeAdded, filter)
            case relation: LogicalRelation =>
              Project(relation.output ++ projectionToBeAdded, relation)
            case other => other
          }
          Project(newCols, Join(newLeft, right, jointype, condition))
        } else {
          proj
        }
      case other => other
    }
    output
  }
}
Example 14
Source File: ExpressionHelper.scala From carbondata with Apache License 2.0 | 5 votes |
package org.apache.carbondata.mv.plans.modular

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, Expression, NamedExpression}
import org.apache.spark.sql.types.{DataType, Metadata}

object ExpressionHelper {

  def createReference(
      name: String,
      dataType: DataType,
      nullable: Boolean,
      metadata: Metadata,
      exprId: ExprId,
      qualifier: Option[String],
      attrRef : NamedExpression = null): AttributeReference = {
    val qf = if (qualifier.nonEmpty) Seq(qualifier.get) else Seq.empty
    AttributeReference(name, dataType, nullable, metadata)(exprId, qf)
  }

  def createAlias(
      child: Expression,
      name: String,
      exprId: ExprId,
      qualifier: Option[String]) : Alias = {
    val qf = if (qualifier.nonEmpty) Seq(qualifier.get) else Seq.empty
    Alias(child, name)(exprId, qf, None)
  }

  def getTheLastQualifier(reference: AttributeReference): String = {
    reference.qualifier.reverse.head
  }
}
Example 15
Source File: ExpressionHelper.scala From carbondata with Apache License 2.0 | 5 votes |
package org.apache.carbondata.mv.plans.modular

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, Expression, NamedExpression}
import org.apache.spark.sql.types.{DataType, Metadata}

object ExpressionHelper {

  def createReference(
      name: String,
      dataType: DataType,
      nullable: Boolean,
      metadata: Metadata,
      exprId: ExprId,
      qualifier: Option[String],
      attrRef : NamedExpression = null): AttributeReference = {
    AttributeReference(name, dataType, nullable, metadata)(exprId, qualifier)
  }

  def createAlias(
      child: Expression,
      name: String,
      exprId: ExprId = NamedExpression.newExprId,
      qualifier: Option[String] = None,
      explicitMetadata: Option[Metadata] = None,
      namedExpr : Option[NamedExpression] = None) : Alias = {
    Alias(child, name)(exprId, qualifier, explicitMetadata)
  }

  def getTheLastQualifier(reference: AttributeReference): String = {
    reference.qualifier.head
  }
}
Example 16
Source File: ExchangeSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits._

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ShuffleExchange(SinglePartition, plan),
      input.map(Row.fromTuple)
    )
  }

  test("compatible BroadcastMode") {
    val mode1 = IdentityBroadcastMode
    val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
    val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

    assert(mode1.compatibleWith(mode1))
    assert(!mode1.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode1))
    assert(mode2.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode3))
    assert(mode3.compatibleWith(mode3))
  }

  test("BroadcastExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan)
    val hashMode = HashedRelationBroadcastMode(output)
    val exchange2 = BroadcastExchangeExec(hashMode, plan)
    val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
    val exchange3 = BroadcastExchangeExec(hashMode2, plan)
    val exchange4 = ReusedExchangeExec(output, exchange3)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)

    assert(!exchange1.sameResult(exchange2))
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(exchange4 sameResult exchange3)
  }

  test("ShuffleExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val part1 = HashPartitioning(output, 1)
    val exchange1 = ShuffleExchange(part1, plan)
    val exchange2 = ShuffleExchange(part1, plan)
    val part2 = HashPartitioning(output, 2)
    val exchange3 = ShuffleExchange(part2, plan)
    val part3 = HashPartitioning(output ++ output, 2)
    val exchange4 = ShuffleExchange(part3, plan)
    val exchange5 = ReusedExchangeExec(output, exchange4)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)
    assert(exchange5 sameResult exchange5)

    assert(exchange1 sameResult exchange2)
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(!exchange4.sameResult(exchange5))
    assert(exchange5 sameResult exchange4)
  }
}
Example 17
Source File: ComputeCurrentTimeSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
}
Example 18
Source File: ResolveCountDistinctStarSuite.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FunSuite
import org.scalatest.Inside._
import org.scalatest.mock.MockitoSugar
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}

import scala.collection.mutable.ArrayBuffer

class ResolveCountDistinctStarSuite extends FunSuite with MockitoSugar {
  val persons = new LogicalRelation(new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]

    override def schema: StructType = StructType(Seq(
      StructField("age", IntegerType),
      StructField("name", StringType)
    ))
  })

  test("Count distinct star is resolved correctly") {
    val projection = persons.select(UnresolvedAlias(
      AggregateExpression(Count(UnresolvedStar(None) :: Nil), Complete, true)))
    val stillNotCompletelyResolvedAggregate = SimpleAnalyzer.execute(projection)
    val resolvedAggregate = ResolveCountDistinctStar(SimpleAnalyzer)
      .apply(stillNotCompletelyResolvedAggregate)
    inside(resolvedAggregate) {
      case Aggregate(Nil,
          ArrayBuffer(Alias(AggregateExpression(Count(expressions), Complete, true), _)), _) =>
        assert(expressions.collect { case a: AttributeReference => a.name }.toSet ==
          Set("name", "age"))
    }
    assert(resolvedAggregate.resolved)
  }
}
Example 19
Source File: RemoveNestedAliasesSuite.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis

import com.sap.spark.PlanTest
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class RemoveNestedAliasesSuite extends FunSuite with MockitoSugar with PlanTest {

  val br1 = new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]

    override def schema: StructType = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)
    ))
  }

  val lr1 = LogicalRelation(br1)
  val nameAtt = lr1.output.find(_.name == "name").get
  val ageAtt = lr1.output.find(_.name == "age").get

  test("Replace alias into aliases") {
    val avgExpr = avg(ageAtt)
    val avgAlias = avgExpr as 'avgAlias
    val aliasAlias = avgAlias as 'aliasAlias
    val aliasAliasAlias = aliasAlias as 'aliasAliasAlias
    val copiedAlias = Alias(avgExpr, aliasAlias.name)(
      exprId = aliasAlias.exprId
    )
    val copiedAlias2 = Alias(avgExpr, aliasAliasAlias.name)(
      exprId = aliasAliasAlias.exprId
    )

    assertResult(
      lr1.groupBy(avgAlias.toAttribute)(avgAlias)
    )(RemoveNestedAliases(lr1.groupBy(avgAlias.toAttribute)(avgAlias)))

    assertResult(
      lr1.groupBy(copiedAlias.toAttribute)(copiedAlias)
    )(RemoveNestedAliases(lr1.groupBy(aliasAlias.toAttribute)(aliasAlias)))

    assertResult(
      lr1.groupBy(copiedAlias2.toAttribute)(copiedAlias2)
    )(RemoveNestedAliases(lr1.groupBy(aliasAliasAlias.toAttribute)(aliasAliasAlias)))
  }

  test("Replace alias into expressions") {
    val ageAlias = ageAtt as 'ageAlias
    val avgExpr = avg(ageAlias) as 'avgAlias
    val correctedAvgExpr = avg(ageAtt) as 'avgAlias
    comparePlans(
      lr1.groupBy(correctedAvgExpr.toAttribute)(correctedAvgExpr),
      RemoveNestedAliases(lr1.groupBy(avgExpr.toAttribute)(avgExpr))
    )
  }
}
Example 20
Source File: UseAliasesForFunctionsInGroupings.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule

object UseAliasesForFunctionsInGroupings extends Rule[LogicalPlan] {

  def apply(plan: LogicalPlan): LogicalPlan =
    plan transformUp {
      case agg@Aggregate(groupingExpressions, aggregateExpressions, child) =>
        val fixedGroupingExpressions = groupingExpressions.map({
          case e: AttributeReference => e
          case e =>
            val aliasOpt = aggregateExpressions.find({
              case Alias(aliasChild, aliasName) => aliasChild == e
              case _ => false
            })
            aliasOpt match {
              case Some(alias) => alias.toAttribute
              case None => sys.error(s"Cannot resolve Alias for $e")
            }
        })
        agg.copy(groupingExpressions = fixedGroupingExpressions)
    }
}
Example 21
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
  // histogram usually contains hundreds of buckets. So we need to test
  // ApproxCountDistinctForIntervals with large number of endpoints
  // (the number of endpoints == the number of buckets + 1).
  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
    val table = "approx_count_distinct_for_intervals_tbl"
    withTable(table) {
      (1 to 100000).toDF("col").createOrReplaceTempView(table)
      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
      val endpoints = (0 to 1000).map(_ * 100000 / 1000)

      // Since approx_count_distinct_for_intervals is not a public function, here we do
      // the computation by constructing logical plan.
      val relation = spark.table(table).logicalPlan
      val attr = relation.output.find(_.name == "col").get
      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
      val aggExpr = aggFunc.toAggregateExpression()
      val namedExpr = Alias(aggExpr, aggExpr.toString)()
      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
        .executedPlan.executeTake(1).head
      val ndvArray = ndvsRow.getArray(0).toLongArray()
      assert(endpoints.length == ndvArray.length + 1)

      // Each bucket has 100 distinct values.
      val expectedNdv = 100
      for (i <- ndvArray.indices) {
        val ndv = ndvArray(i)
        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
      }
    }
  }
}
Example 22
Source File: ComputeCurrentTimeSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
}
Example 23
Source File: OptimizerStructuralIntegrityCheckerSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {

  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Project(projectList, child) =>
        val newAttr = UnresolvedAttribute("unresolvedAttr")
        Project(projectList ++ Seq(newAttr), child)
    }
  }

  object Optimize extends Optimizer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SQLConf())) {
    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
    override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches
  }

  test("check for invalid plan after execution of rule") {
    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
    assert(analyzed.resolved)
    val message = intercept[TreeNodeException[LogicalPlan]] {
      Optimize.execute(analyzed)
    }.getMessage
    val ruleName = OptimizeRuleBreakSI.ruleName
    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
    assert(message.contains("the structural integrity of the plan is broken"))
  }
}
Example 24
Source File: LookupFunctionsSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis

import java.net.URI

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

class LookupFunctionsSuite extends PlanTest {

  test("SPARK-23486: the functionExists for the Persistent function check") {
    val externalCatalog = new CustomInMemoryCatalog
    val conf = new SQLConf()
    val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf)
    val analyzer = {
      catalog.createDatabase(
        CatalogDatabase("default", "", new URI("loc"), Map.empty),
        ignoreIfExists = false)
      new Analyzer(catalog, conf)
    }

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false)
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    val plan = Project(
      Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(),
        Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(),
        Alias(unresolvedRegisteredFunc, "call5")()),
      table("TaBlE"))
    analyzer.LookupFunctions.apply(plan)

    assert(externalCatalog.getFunctionExistsCalledTimes == 1)
    assert(analyzer.LookupFunctions.normalizeFuncName(
      unresolvedPersistentFunc.name).database == Some("default"))
  }

  test("SPARK-23486: the functionExists for the Registered function check") {
    val externalCatalog = new InMemoryCatalog
    val conf = new SQLConf()
    val customerFunctionReg = new CustomerFunctionRegistry
    val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf)
    val analyzer = {
      catalog.createDatabase(
        CatalogDatabase("default", "", new URI("loc"), Map.empty),
        ignoreIfExists = false)
      new Analyzer(catalog, conf)
    }

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    val plan = Project(
      Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()),
      table("TaBlE"))
    analyzer.LookupFunctions.apply(plan)

    assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
    assert(analyzer.LookupFunctions.normalizeFuncName(
      unresolvedRegisteredFunc.name).database == Some("default"))
  }
}

class CustomerFunctionRegistry extends SimpleFunctionRegistry {

  private var isRegisteredFunctionCalledTimes: Int = 0;

  override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized {
    isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1
    true
  }

  def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes
}

class CustomInMemoryCatalog extends InMemoryCatalog {

  private var functionExistsCalledTimes: Int = 0

  override def functionExists(db: String, funcName: String): Boolean = synchronized {
    functionExistsCalledTimes = functionExistsCalledTimes + 1
    true
  }

  def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes
}
Example 25
Source File: LogicalPlanSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

class LogicalPlanSuite extends SparkFunSuite {
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  test("transformUp runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan transformUp function
    assert(invocationCount === 1)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 1)
  }

  test("transformUp runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan transformUp function
    assert(invocationCount === 2)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 2)
  }

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true)

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output
    }

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
  }

  test("transformExpressions works with a Stream") {
    val id1 = NamedExpression.newExprId
    val id2 = NamedExpression.newExprId
    val plan = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(2), "b")(exprId = id2)),
      OneRowRelation())
    val result = plan.transformExpressions {
      case Literal(v: Int, IntegerType) if v != 1 => Literal(v + 1, IntegerType)
    }
    val expected = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(3), "b")(exprId = id2)),
      OneRowRelation())
    assert(result.sameResult(expected))
  }
}
Example 26
Source File: ResolveTableValuedFunctions.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

    tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
        "numPartitions" -> IntegerType) {
      case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
        Range(start, end, step, Some(numPartitions))
    })
  )

  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
      // The whole resolution is somewhat difficult to understand here due to too much abstractions.
      // We should probably rewrite the following at some point. Reynold was just here to improve
      // error messages and didn't have time to do a proper rewrite.
      val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
        case Some(tvf) =>
          def failAnalysis(): Nothing = {
            val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
            u.failAnalysis(
              s"""error: table-valued function ${u.functionName} with alternatives:
                |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
                |cannot be applied to: ($argTypes)""".stripMargin)
          }
          val resolved = tvf.flatMap { case (argList, resolver) =>
            argList.implicitCast(u.functionArgs) match {
              case Some(casted) =>
                try {
                  Some(resolver(casted.map(_.eval())))
                } catch {
                  case e: AnalysisException => failAnalysis()
                }
              case _ => None
            }
          }
          resolved.headOption.getOrElse {
            failAnalysis()
          }
        case _ =>
          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
      }

      // If alias names assigned, add `Project` with the aliases
      if (u.outputNames.nonEmpty) {
        val outputAttrs = resolvedFunc.output
        // Checks if the number of the aliases is equal to expected one
        if (u.outputNames.size != outputAttrs.size) {
          u.failAnalysis(s"Number of given aliases does not match number of output columns. " +
            s"Function name: ${u.functionName}; number of aliases: " +
            s"${u.outputNames.size}; number of output columns: ${outputAttrs.size}.")
        }
        val aliases = outputAttrs.zip(u.outputNames).map {
          case (attr, name) => Alias(attr, name)()
        }
        Project(aliases, resolvedFunc)
      } else {
        resolvedFunc
      }
  }
}
Example 27
Source File: view.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

object EliminateView extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // The child should have the same output attributes with the View operator, so we simply
    // remove the View operator.
    case View(_, output, child) =>
      assert(output == child.output,
        s"The output of the child ${child.output.mkString("[", ",", "]")} is different from the " +
          s"view output ${output.mkString("[", ",", "]")}")
      child
  }
}
Example 28
Source File: ProjectEstimation.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}

object ProjectEstimation {
  import EstimationUtils._

  def estimate(project: Project): Option[Statistics] = {
    if (rowCountsExist(project.child)) {
      val childStats = project.child.stats
      val inputAttrStats = childStats.attributeStats
      // Match alias with its child's column stat
      val aliasStats = project.expressions.collect {
        case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) =>
          alias.toAttribute -> inputAttrStats(attr)
      }
      val outputAttrStats =
        getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
      Some(childStats.copy(
        sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
        attributeStats = outputAttrStats))
    } else {
      None
    }
  }
}
Example 29
Source File: SparkWrapper.scala From tispark with Apache License 2.0 | 5 votes |
package com.pingcap.tispark

import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.types.{DataType, Metadata}

object SparkWrapper {
  def getVersion: String = {
    "SparkWrapper-2.4"
  }

  def newSubqueryAlias(identifier: String, child: LogicalPlan): SubqueryAlias = {
    SubqueryAlias(identifier, child)
  }

  def newAlias(child: Expression, name: String): Alias = {
    Alias(child, name)()
  }

  def newAttributeReference(
      name: String,
      dataType: DataType,
      nullable: Boolean,
      metadata: Metadata): AttributeReference = {
    AttributeReference(name, dataType, nullable, metadata)()
  }

  def callSessionCatalogCreateTable(
      obj: SessionCatalog,
      tableDefinition: CatalogTable,
      ignoreIfExists: Boolean): Unit = {
    obj.createTable(tableDefinition, ignoreIfExists)
  }
}
Example 30
Source File: SparkWrapper.scala From tispark with Apache License 2.0 | 5 votes |
package com.pingcap.tispark

import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.types.{DataType, Metadata}

object SparkWrapper {
  def getVersion: String = {
    "SparkWrapper-2.3"
  }

  def newSubqueryAlias(identifier: String, child: LogicalPlan): SubqueryAlias = {
    SubqueryAlias(identifier, child)
  }

  def newAlias(child: Expression, name: String): Alias = {
    Alias(child, name)()
  }

  def newAttributeReference(
      name: String,
      dataType: DataType,
      nullable: Boolean,
      metadata: Metadata): AttributeReference = {
    AttributeReference(name, dataType, nullable, metadata)()
  }

  def callSessionCatalogCreateTable(
      obj: SessionCatalog,
      tableDefinition: CatalogTable,
      ignoreIfExists: Boolean): Unit = {
    obj.createTable(tableDefinition, ignoreIfExists)
  }
}
Example 31
Source File: ExchangeSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits._

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ShuffleExchange(SinglePartition, plan),
      input.map(Row.fromTuple)
    )
  }

  test("compatible BroadcastMode") {
    val mode1 = IdentityBroadcastMode
    val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
    val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

    assert(mode1.compatibleWith(mode1))
    assert(!mode1.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode1))
    assert(mode2.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode3))
    assert(mode3.compatibleWith(mode3))
  }

  test("BroadcastExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan)
    val hashMode = HashedRelationBroadcastMode(output)
    val exchange2 = BroadcastExchangeExec(hashMode, plan)
    val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
    val exchange3 = BroadcastExchangeExec(hashMode2, plan)
    val exchange4 = ReusedExchangeExec(output, exchange3)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)

    assert(!exchange1.sameResult(exchange2))
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(exchange4 sameResult exchange3)
  }

  test("ShuffleExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val part1 = HashPartitioning(output, 1)
    val exchange1 = ShuffleExchange(part1, plan)
    val exchange2 = ShuffleExchange(part1, plan)
    val part2 = HashPartitioning(output, 2)
    val exchange3 = ShuffleExchange(part2, plan)
    val part3 = HashPartitioning(output ++ output, 2)
    val exchange4 = ShuffleExchange(part3, plan)
    val exchange5 = ReusedExchangeExec(output, exchange4)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)
    assert(exchange5 sameResult exchange5)

    assert(exchange1 sameResult exchange2)
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(!exchange4.sameResult(exchange5))
    assert(exchange5 sameResult exchange4)
  }
}