org.apache.spark.sql.catalyst.expressions.Literal Scala Examples
The following examples show how to use org.apache.spark.sql.catalyst.expressions.Literal.
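Before the project examples, here is a minimal, self-contained sketch of the Literal factories and accessors that the snippets below rely on. It is not taken from any of the listed projects, and the values are arbitrary.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.IntegerType

val one = Literal(1)                     // the data type (IntegerType) is inferred from the Scala value
val nullInt = Literal(null, IntegerType) // a typed null literal, as used in several test suites below
val two = Literal.create(2, IntegerType) // factory that also converts the value to Catalyst's internal representation

println(one.sql)       // renders the literal as SQL text ("1")
println(one.foldable)  // literals are foldable leaf expressions, so this prints true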
Example 1
Source File: MyUDF.scala From spark-tools with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.TimestampType

object MyUDF {

  private def myTimestampCast(xs: Seq[Expression]): Expression = {
    val expSource = xs.head
    expSource.dataType match {
      case LongType =>
        new Column(expSource).divide(Literal(1000)).cast(TimestampType).expr
      case TimestampType =>
        expSource
    }
  }

  def register(sparkSession: SparkSession): Unit =
    sparkSession.sessionState.functionRegistry
      .registerFunction(FunctionIdentifier("toTs", None), myTimestampCast)
}
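As a usage illustration (not part of the spark-tools project; the session setup and the sample value are made up), the registered function can be called from SQL once register has run:

import org.apache.spark.sql.{MyUDF, SparkSession}

val spark = SparkSession.builder().master("local[*]").appName("toTs-demo").getOrCreate()
MyUDF.register(spark)

// A made-up epoch-millisecond value; toTs divides it by 1000 and casts the result to a timestamp.
spark.sql("SELECT toTs(1561113600000) AS ts").show()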
Example 2
Source File: MyUDF.scala From spark-tools with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.TimestampType

object MyUDF {

  private def myTimestampCast(xs: Seq[Expression]): Expression = {
    val expSource = xs.head
    expSource.dataType match {
      case LongType =>
        new Column(expSource).divide(Literal(1000)).cast(TimestampType).expr
      case TimestampType =>
        expSource
    }
  }

  def register(sparkSession: SparkSession): Unit =
    sparkSession.sessionState.functionRegistry
      .registerFunction("toTs", myTimestampCast)
}
Example 3
Source File: LocalRelation.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) def apply(output1: StructField, output: StructField*): LocalRelation = { new LocalRelation(StructType(output1 +: output).toAttributes) } def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } } case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.") override final def newInstance(): this.type = { LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type] } override protected def stringArgs: Iterator[Any] = { if (data.isEmpty) { Iterator("<empty>", output) } else { Iterator(output) } } override def sameResult(plan: LogicalPlan): Boolean = { plan.canonicalized match { case LocalRelation(otherOutput, otherData) => otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data case _ => false } } override lazy val statistics = Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) val types = output.map(_.dataType) val rows = data.map { row => val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql } cells.mkString("(", ", ", ")") } "VALUES " + rows.mkString(", ") + " AS " + inlineTableName + output.map(_.name).mkString("(", ", ", ")") } }
Example 4
Source File: SubstituteUnresolvedOrdinals.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType

class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
}
Example 5
Source File: RuleExecutorSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

class RuleExecutorSuite extends SparkFunSuite {
  object DecrementLiterals extends Rule[Expression] {
    def apply(e: Expression): Expression = e transform {
      case IntegerLiteral(i) if i > 0 => Literal(i - 1)
    }
  }

  test("only once") {
    object ApplyOnce extends RuleExecutor[Expression] {
      val batches = Batch("once", Once, DecrementLiterals) :: Nil
    }
    assert(ApplyOnce.execute(Literal(10)) === Literal(9))
  }

  test("to fixed point") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil
    }
    assert(ToFixedPoint.execute(Literal(10)) === Literal(0))
  }

  test("to maxIterations") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
    }
    val message = intercept[TreeNodeException[LogicalPlan]] {
      ToFixedPoint.execute(Literal(100))
    }.getMessage
    assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
  }
}
Example 6
Source File: SubstituteUnresolvedOrdinalsSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.SimpleCatalystConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) private lazy val a = testRelation2.output(0) private lazy val b = testRelation2.output(1) test("unresolved ordinal should not be unresolved") { // Expression OrderByOrdinal is unresolved. assert(!UnresolvedOrdinal(0).resolved) } test("order by ordinal") { // Tests order by ordinal, apply single rule. val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan), testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) // Tests order by ordinal, do full analysis checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) // order by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } test("group by ordinal") { // Tests group by ordinal, apply single rule. val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan2), testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) // Tests group by ordinal, do full analysis checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) // group by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } }
Example 7
Source File: ResolveInlineTablesSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types.{LongType, NullType} class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) assert(ResolveInlineTables(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted = ResolveInlineTables.convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) assert(converted.data(0).getLong(0) == 1L) assert(converted.data(1).getLong(0) == 2L) } test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted1 = ResolveInlineTables.convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) val converted2 = ResolveInlineTables.convert(table2) assert(converted2.schema.fields(0).nullable) } }
Example 8
Source File: PartitioningSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} class PartitioningSuite extends SparkFunSuite { test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { val expressions = Seq(Literal(2), Literal(3)) // Consider two HashPartitionings that have the same _set_ of hash expressions but which are // created with different orderings of those expressions: val partitioningA = HashPartitioning(expressions, 100) val partitioningB = HashPartitioning(expressions.reverse, 100) // These partitionings are not considered equal: assert(partitioningA != partitioningB) // However, they both satisfy the same clustered distribution: val distribution = ClusteredDistribution(expressions) assert(partitioningA.satisfies(distribution)) assert(partitioningB.satisfies(distribution)) // These partitionings compute different hashcodes for the same input row: def computeHashCode(partitioning: HashPartitioning): Int = { val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) hashExprProj.apply(InternalRow.empty).hashCode() } assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) // Thus, these partitionings are incompatible: assert(!partitioningA.compatibleWith(partitioningB)) assert(!partitioningB.compatibleWith(partitioningA)) assert(!partitioningA.guarantees(partitioningB)) assert(!partitioningB.guarantees(partitioningA)) // Just to be sure that we haven't cheated by having these methods always return false, // check that identical partitionings are still compatible with and guarantee each other: assert(partitioningA === partitioningA) assert(partitioningA.guarantees(partitioningA)) assert(partitioningA.compatibleWith(partitioningA)) } }
Example 9
Source File: RewriteDistinctAggregatesSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) val nullInt = Literal(null, IntegerType) val nullString = Literal(null, StringType) val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { case Aggregate(_, _, Aggregate(_, _, _: Expand)) => case _ => fail(s"Plan is not rewritten:\n$rewrite") } test("single distinct group") { val input = testRelation .groupBy('a)(countDistinct('e)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), max('b).as('agg2)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with non-partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), CollectSet('b).toAggregateExpression().as('agg2)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with partial aggregates") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with non-partial aggregates") { val input = testRelation .groupBy('a)( countDistinct('b, 'c), countDistinct('d), CollectSet('b).toAggregateExpression()) .analyze checkRewrite(RewriteDistinctAggregates(input)) } }
Example 10
Source File: ComputeCurrentTimeSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils class ComputeCurrentTimeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) } test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) val min = System.currentTimeMillis() * 1000 val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = (System.currentTimeMillis() + 1) * 1000 val lits = new scala.collection.mutable.ArrayBuffer[Long] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Long] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } test("analyzer should replace current_date with literals") { val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) val lits = new scala.collection.mutable.ArrayBuffer[Int] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Int] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } }
Example 11
Source File: AggregateOptimizeSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), FoldablePropagation, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("remove literals in grouping expression") { val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("do not remove all grouping expressions if they are all literals") { val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) comparePlans(optimized, correctAnswer) } test("Remove aliased literals") { val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("remove repetition in grouping expression") { val input = LocalRelation('a.int, 'b.int, 'c.int) val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } }
Example 12
Source File: subquery.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
}
Example 13
Source File: ExchangeSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { import testImplicits._ test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), plan => ShuffleExchange(SinglePartition, plan), input.map(Row.fromTuple) ) } test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) assert(!mode2.compatibleWith(mode1)) assert(mode2.compatibleWith(mode2)) assert(!mode2.compatibleWith(mode3)) assert(mode3.compatibleWith(mode3)) } test("BroadcastExchange same result") { val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan) val hashMode = HashedRelationBroadcastMode(output) val exchange2 = BroadcastExchangeExec(hashMode, plan) val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) val exchange3 = BroadcastExchangeExec(hashMode2, plan) val exchange4 = ReusedExchangeExec(output, exchange3) assert(exchange1 sameResult exchange1) assert(exchange2 sameResult exchange2) assert(exchange3 sameResult exchange3) assert(exchange4 sameResult exchange4) assert(!exchange1.sameResult(exchange2)) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) assert(exchange4 sameResult exchange3) } test("ShuffleExchange same result") { val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) val part1 = HashPartitioning(output, 1) val exchange1 = ShuffleExchange(part1, plan) val exchange2 = ShuffleExchange(part1, plan) val part2 = HashPartitioning(output, 2) val exchange3 = ShuffleExchange(part2, plan) val part3 = HashPartitioning(output ++ output, 2) val exchange4 = ShuffleExchange(part3, plan) val exchange5 = ReusedExchangeExec(output, exchange4) assert(exchange1 sameResult exchange1) assert(exchange2 sameResult exchange2) assert(exchange3 sameResult exchange3) assert(exchange4 sameResult exchange4) assert(exchange5 sameResult exchange5) assert(exchange1 sameResult exchange2) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) assert(!exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } }
Example 14
Source File: TakeOrderedAndProjectSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) } private def generateRandomInputData(): DataFrame = { val schema = new StructType() .add("a", IntegerType, nullable = false) .add("b", IntegerType, nullable = false) val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) } private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) val limit = 250 val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, SortExec(sortOrder, true, input))), sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter( TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, ProjectExec(Seq(input.output.last), SortExec(sortOrder, true, input)))), sortAnswers = false) } } }
Example 15
Source File: DeltaPushFilter.scala From connectors with Apache License 2.0
package org.apache.spark.sql.delta

import scala.collection.immutable.HashSet
import scala.collection.JavaConverters._

import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, SerializationUtilities}
import org.apache.hadoop.hive.ql.lib._
import org.apache.hadoop.hive.ql.parse.SemanticException
import org.apache.hadoop.hive.ql.plan.{ExprNodeColumnDesc, ExprNodeConstantDesc, ExprNodeGenericFuncDesc}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, InSet, LessThan, LessThanOrEqual, Like, Literal, Not}

object DeltaPushFilter extends Logging {
  lazy val supportedPushDownUDFs = Array(
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualNS",
    "org.apache.hadoop.hive.ql.udf.UDFLike",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn"
  )

  def partitionFilterConverter(hiveFilterExprSeriablized: String): Seq[Expression] = {
    if (hiveFilterExprSeriablized != null) {
      val filterExpr = SerializationUtilities.deserializeExpression(hiveFilterExprSeriablized)
      val opRules = new java.util.LinkedHashMap[Rule, NodeProcessor]()
      val nodeProcessor = new NodeProcessor() {
        @throws[SemanticException]
        def process(nd: Node, stack: java.util.Stack[Node], procCtx: NodeProcessorCtx,
            nodeOutputs: Object*): Object = {
          nd match {
            case e: ExprNodeGenericFuncDesc if FunctionRegistry.isOpAnd(e) =>
              nodeOutputs.map(_.asInstanceOf[Expression]).reduce(And)
            case e: ExprNodeGenericFuncDesc =>
              val (columnDesc, constantDesc) =
                if (nd.getChildren.get(0).isInstanceOf[ExprNodeColumnDesc]) {
                  (nd.getChildren.get(0), nd.getChildren.get(1))
                } else {
                  (nd.getChildren.get(1), nd.getChildren.get(0))
                }
              val columnAttr = UnresolvedAttribute(
                columnDesc.asInstanceOf[ExprNodeColumnDesc].getColumn)
              val constantVal = Literal(constantDesc.asInstanceOf[ExprNodeConstantDesc].getValue)
              nd.asInstanceOf[ExprNodeGenericFuncDesc].getGenericUDF match {
                case f: GenericUDFOPNotEqualNS => Not(EqualNullSafe(columnAttr, constantVal))
                case f: GenericUDFOPNotEqual => Not(EqualTo(columnAttr, constantVal))
                case f: GenericUDFOPEqualNS => EqualNullSafe(columnAttr, constantVal)
                case f: GenericUDFOPEqual => EqualTo(columnAttr, constantVal)
                case f: GenericUDFOPGreaterThan => GreaterThan(columnAttr, constantVal)
                case f: GenericUDFOPEqualOrGreaterThan => GreaterThanOrEqual(columnAttr, constantVal)
                case f: GenericUDFOPLessThan => LessThan(columnAttr, constantVal)
                case f: GenericUDFOPEqualOrLessThan => LessThanOrEqual(columnAttr, constantVal)
                case f: GenericUDFBridge if f.getUdfName.equals("like") =>
                  Like(columnAttr, constantVal)
                case f: GenericUDFIn =>
                  val inConstantVals = nd.getChildren.asScala
                    .filter(_.isInstanceOf[ExprNodeConstantDesc])
                    .map(_.asInstanceOf[ExprNodeConstantDesc].getValue)
                    .map(Literal(_)).toSet
                  InSet(columnAttr, HashSet() ++ inConstantVals)
                case _ =>
                  throw new RuntimeException(s"Unsupported func(${nd.getName}) " +
                    s"which can not be pushed down to delta")
              }
            case _ => null
          }
        }
      }

      val disp = new DefaultRuleDispatcher(nodeProcessor, opRules, null)
      val ogw = new DefaultGraphWalker(disp)
      val topNodes = new java.util.ArrayList[Node]()
      topNodes.add(filterExpr)
      val nodeOutput = new java.util.HashMap[Node, Object]()
      try {
        ogw.startWalking(topNodes, nodeOutput)
      } catch {
        case ex: Exception => throw new RuntimeException(ex)
      }
      logInfo(s"converted partition filter expr:" +
        s"${nodeOutput.get(filterExpr).asInstanceOf[Expression].toJSON}")
      Seq(nodeOutput.get(filterExpr).asInstanceOf[Expression])
    } else Seq.empty[org.apache.spark.sql.catalyst.expressions.Expression]
  }
}
Example 16
Source File: ValueInterval.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

object ValueInterval {

  def intersect(r1: ValueInterval, r2: ValueInterval, dt: DataType): (Option[Any], Option[Any]) = {
    (r1, r2) match {
      case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) =>
        // binary/string types don't support intersecting.
        (None, None)
      case (n1: NumericValueInterval, n2: NumericValueInterval) =>
        // Choose the maximum of two min values, and the minimum of two max values.
        val newMin = if (n1.min <= n2.min) n2.min else n1.min
        val newMax = if (n1.max <= n2.max) n1.max else n2.max
        (Some(EstimationUtils.fromDouble(newMin, dt)), Some(EstimationUtils.fromDouble(newMax, dt)))
      case _ =>
        throw new UnsupportedOperationException(s"Not supported pair: $r1, $r2 at intersect()")
    }
  }
}
Example 17
Source File: LocalRelation.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
  def apply(output: Attribute*): LocalRelation = new LocalRelation(output)

  def apply(output1: StructField, output: StructField*): LocalRelation = {
    new LocalRelation(StructType(output1 +: output).toAttributes)
  }

  def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }

  def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }
}

case class LocalRelation(
    output: Seq[Attribute],
    data: Seq[InternalRow] = Nil,
    override val isStreaming: Boolean = false)
  extends LeafNode with analysis.MultiInstanceRelation {

  override final def newInstance(): this.type = {
    LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type]
  }

  override protected def stringArgs: Iterator[Any] = {
    if (data.isEmpty) {
      Iterator("<empty>", output)
    } else {
      Iterator(output)
    }
  }

  override def computeStats(): Statistics =
    Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length)

  def toSQL(inlineTableName: String): String = {
    require(data.nonEmpty)
    val types = output.map(_.dataType)
    val rows = data.map { row =>
      val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql }
      cells.mkString("(", ", ", ")")
    }
    "VALUES " + rows.mkString(", ") + " AS " + inlineTableName +
      output.map(_.name).mkString("(", ", ", ")")
  }
}
Example 18
Source File: SubstituteUnresolvedOrdinals.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
}
Example 19
Source File: LogicalPlanSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 private val function: PartialFunction[LogicalPlan, LogicalPlan] = { case p: Project => invocationCount += 1 p } private val testRelation = LocalRelation() test("transformUp runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) plan transformUp function assert(invocationCount === 1) invocationCount = 0 plan transformDown function assert(invocationCount === 1) } test("transformUp runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) plan transformUp function assert(invocationCount === 2) invocationCount = 0 plan transformDown function assert(invocationCount === 2) } test("isStreaming") { val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val incrementalRelation = LocalRelation( Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true) case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output } require(relation.isStreaming === false) require(incrementalRelation.isStreaming === true) assert(TestBinaryRelation(relation, relation).isStreaming === false) assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true) assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } test("transformExpressions works with a Stream") { val id1 = NamedExpression.newExprId val id2 = NamedExpression.newExprId val plan = Project(Stream( Alias(Literal(1), "a")(exprId = id1), Alias(Literal(2), "b")(exprId = id2)), OneRowRelation()) val result = plan.transformExpressions { case Literal(v: Int, IntegerType) if v != 1 => Literal(v + 1, IntegerType) } val expected = Project(Stream( Alias(Literal(1), "a")(exprId = id1), Alias(Literal(3), "b")(exprId = id2)), OneRowRelation()) assert(result.sameResult(expected)) } }
Example 20
Source File: QueryPlanSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)
    CurrentOrigin.reset()

    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)
    }

    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))
  }
}
Example 21
Source File: RuleExecutorSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } val message = intercept[TreeNodeException[LogicalPlan]] { ToFixedPoint.execute(Literal(100)) }.getMessage assert(message.contains("Max iterations (10) reached for batch fixedPoint")) } test("structural integrity checker") { object WithSIChecker extends RuleExecutor[Expression] { override protected def isPlanIntegral(expr: Expression): Boolean = expr match { case IntegerLiteral(_) => true case _ => false } val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(WithSIChecker.execute(Literal(10)) === Literal(9)) val message = intercept[TreeNodeException[LogicalPlan]] { WithSIChecker.execute(Literal(10.1)) }.getMessage assert(message.contains("the structural integrity of the plan is broken")) } }
Example 22
Source File: SubstituteUnresolvedOrdinalsSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.internal.SQLConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val a = testRelation2.output(0) private lazy val b = testRelation2.output(1) test("unresolved ordinal should not be unresolved") { // Expression OrderByOrdinal is unresolved. assert(!UnresolvedOrdinal(0).resolved) } test("order by ordinal") { // Tests order by ordinal, apply single rule. val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan), testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) // Tests order by ordinal, do full analysis checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) // order by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.ORDER_BY_ORDINAL -> false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } test("group by ordinal") { // Tests group by ordinal, apply single rule. val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan2), testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) // Tests group by ordinal, do full analysis checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) // group by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.GROUP_BY_ORDINAL -> false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } }
Example 23
Source File: ResolveInlineTablesSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) assert(ResolveInlineTables(conf)(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted = ResolveInlineTables(conf).convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) assert(converted.data(0).getLong(0) == 1L) assert(converted.data(1).getLong(0) == 2L) } test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) val withTimeZone = ResolveTimeZone(conf).apply(table) val LocalRelation(output, data, _) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] assert(output.map(_.dataType) == Seq(TimestampType)) assert(data.size == 1) assert(data.head.getLong(0) == correct) } test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted1 = ResolveInlineTables(conf).convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) val converted2 = ResolveInlineTables(conf).convert(table2) assert(converted2.schema.fields(0).nullable) } }
Example 24
Source File: ConvertToLocalRelationSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class ConvertToLocalRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Nil } test("Project on LocalRelation should be turned into a single LocalRelation") { val testRelation = LocalRelation( LocalRelation('a.int, 'b.int).output, InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) val correctAnswer = LocalRelation( LocalRelation('a1.int, 'b1.int).output, InternalRow(1, 3) :: InternalRow(4, 6) :: Nil) val projectOnLocal = testRelation.select( UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) val optimized = Optimize.execute(projectOnLocal.analyze) comparePlans(optimized, correctAnswer) } test("Filter on LocalRelation should be turned into a single LocalRelation") { val testRelation = LocalRelation( LocalRelation('a.int, 'b.int).output, InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) val correctAnswer = LocalRelation( LocalRelation('a1.int, 'b1.int).output, InternalRow(1, 3) :: Nil) val filterAndProjectOnLocal = testRelation .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) .where(LessThan(UnresolvedAttribute("b1"), Literal.create(6))) val optimized = Optimize.execute(filterAndProjectOnLocal.analyze) comparePlans(optimized, correctAnswer) } }
Example 25
Source File: OptimizerStructuralIntegrityCheckerSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { object OptimizeRuleBreakSI extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Project(projectList, child) => val newAttr = UnresolvedAttribute("unresolvedAttr") Project(projectList ++ Seq(newAttr), child) } } object Optimize extends Optimizer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, new SQLConf())) { val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI) override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches } test("check for invalid plan after execution of rule") { val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze assert(analyzed.resolved) val message = intercept[TreeNodeException[LogicalPlan]] { Optimize.execute(analyzed) }.getMessage val ruleName = OptimizeRuleBreakSI.ruleName assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI")) assert(message.contains("the structural integrity of the plan is broken")) } }
Example 26
Source File: RewriteDistinctAggregatesSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) val nullInt = Literal(null, IntegerType) val nullString = Literal(null, StringType) val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { case Aggregate(_, _, Aggregate(_, _, _: Expand)) => case _ => fail(s"Plan is not rewritten:\n$rewrite") } test("single distinct group") { val input = testRelation .groupBy('a)(countDistinct('e)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), max('b).as('agg2)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("multiple distinct groups") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with partial aggregates") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with non-partial aggregates") { val input = testRelation .groupBy('a)( countDistinct('b, 'c), countDistinct('d), CollectSet('b).toAggregateExpression()) .analyze checkRewrite(RewriteDistinctAggregates(input)) } }
Example 27
Source File: ComputeCurrentTimeSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils class ComputeCurrentTimeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) } test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) val min = System.currentTimeMillis() * 1000 val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = (System.currentTimeMillis() + 1) * 1000 val lits = new scala.collection.mutable.ArrayBuffer[Long] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Long] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } test("analyzer should replace current_date with literals") { val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) val lits = new scala.collection.mutable.ArrayBuffer[Int] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Int] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } }
Example 28
Source File: AggregateOptimizeSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} class AggregateOptimizeSuite extends PlanTest { override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), FoldablePropagation, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("remove literals in grouping expression") { val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("do not remove all grouping expressions if they are all literals") { val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) comparePlans(optimized, correctAnswer) } test("Remove aliased literals") { val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("remove repetition in grouping expression") { val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } }
Example 29
Source File: subquery.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
}
Example 30
Source File: TakeOrderedAndProjectSuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) } private def generateRandomInputData(): DataFrame = { val schema = new StructType() .add("a", IntegerType, nullable = false) .add("b", IntegerType, nullable = false) val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) } private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) val limit = 250 val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, SortExec(sortOrder, true, input))), sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter( TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, ProjectExec(Seq(input.output.last), SortExec(sortOrder, true, input)))), sortAnswers = false) } } }
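A small sketch of the trick behind noOpFilter above: Literal(true) is a BooleanType expression that evaluates to true for every row, so wrapping a plan in FilterExec(Literal(true), plan) adds an operator without changing the result.

import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Literal}
import org.apache.spark.sql.types.BooleanType

// A constant-true predicate keeps every row.
val alwaysTrue = Literal(true)
assert(alwaysTrue.dataType == BooleanType)
assert(alwaysTrue.eval(EmptyRow) == true)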
Example 31
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.test.SharedSQLContext class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height // histogram usually contains hundreds of buckets. So we need to test // ApproxCountDistinctForIntervals with large number of endpoints // (the number of endpoints == the number of buckets + 1). test("test ApproxCountDistinctForIntervals with large number of endpoints") { val table = "approx_count_distinct_for_intervals_tbl" withTable(table) { (1 to 100000).toDF("col").createOrReplaceTempView(table) // percentiles of 0, 0.001, 0.002 ... 0.999, 1 val endpoints = (0 to 1000).map(_ * 100000 / 1000) // Since approx_count_distinct_for_intervals is not a public function, here we do // the computation by constructing logical plan. val relation = spark.table(table).logicalPlan val attr = relation.output.find(_.name == "col").get val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_)))) val aggExpr = aggFunc.toAggregateExpression() val namedExpr = Alias(aggExpr, aggExpr.toString)() val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation)) .executedPlan.executeTake(1).head val ndvArray = ndvsRow.getArray(0).toLongArray() assert(endpoints.length == ndvArray.length + 1) // Each bucket has 100 distinct values. val expectedNdv = 100 for (i <- ndvArray.indices) { val ndv = ndvArray(i) val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.") } } } }
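A sketch of how the endpoints reach the aggregate above: they are wrapped as Literal children of a CreateArray expression, and with only literal children the array stays foldable. The endpoint values here are made up for the sketch.

import org.apache.spark.sql.catalyst.expressions.{CreateArray, Literal}

// Endpoints become a literal array expression for ApproxCountDistinctForIntervals.
val endpoints = Seq(0, 25, 50, 75, 100)
val endpointArray = CreateArray(endpoints.map(Literal(_)))
assert(endpointArray.foldable)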
Example 32
Source File: AnnotationParser.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.parser

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AnnotationReference, Expression, Literal}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

protected def toTableMetadata(metadata: Map[String, Expression]): Metadata = {
  val res = new MetadataBuilder()
  metadata.foreach {
    case (k, v: Literal) =>
      v.dataType match {
        case StringType =>
          if (k.equals("?")) {
            sys.error("column metadata key can not be ?")
          }
          if (k.equals("*")) {
            sys.error("column metadata key can not be *")
          }
          res.putString(k, v.value.asInstanceOf[UTF8String].toString)
        case LongType => res.putLong(k, v.value.asInstanceOf[Long])
        case DoubleType => res.putDouble(k, v.value.asInstanceOf[Double])
        case NullType => res.putString(k, null)
        case a: ArrayType => res.putString(k, v.value.toString)
      }
    case (k, v: AnnotationReference) =>
      sys.error("column metadata can not have a reference to another column metadata")
  }
  res.build()
}
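A sketch (not from the project) of why the StringType branch casts through UTF8String: a StringType Literal stores its value as UTF8String rather than java.lang.String, so an explicit conversion is needed before putString.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// String literals carry UTF8String values internally.
val lit = Literal("some value")
assert(lit.dataType == StringType)
val javaString = lit.value.asInstanceOf[UTF8String].toString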
Example 33
Source File: MetadataAccessorSuite.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.types import org.apache.spark.sql.catalyst.expressions.{Literal, Expression} import org.scalatest.FunSuite // scalastyle:off magic.number class MetadataAccessorSuite extends FunSuite { test("expression map is written correctly to Metadata") { val expressionMap = Map[String, Expression] ( "stringKey" -> Literal.create("stringValue", StringType), "longKey" -> Literal.create(10L, LongType), "doubleKey" -> Literal.create(1.234, DoubleType), "nullKey" -> Literal.create(null, NullType) ) val actual = MetadataAccessor.expressionMapToMetadata(expressionMap) assertResult("stringValue")(actual.getString("stringKey")) assertResult(10)(actual.getLong("longKey")) assertResult(1.234)(actual.getDouble("doubleKey")) assertResult(null)(actual.getString("nullKey")) } test("metadata propagation works correctly") { val oldMetadata = new MetadataBuilder() .putString("key1", "value1") .putString("key2", "value2") .putLong("key3", 10L) .build() val newMetadata = new MetadataBuilder() .putString("key1", "overriden") .putString("key4", "value4") .build() val expected = new MetadataBuilder() .putString("key1", "overriden") .putString("key2", "value2") .putLong("key3", 10L) .putString("key4", "value4") .build() val actual = MetadataAccessor.propagateMetadata(oldMetadata, newMetadata) assertResult(expected)(actual) } test("filter metadata works correctly") { val metadata = new MetadataBuilder() .putString("key1", "value1") .putString("key2", "value2") .putLong("key3", 10L) .build() val expected1 = new MetadataBuilder() .putString("key1", "value1") .build() assertResult(expected1)(MetadataAccessor.filterMetadata(metadata, ("key1" :: Nil).toSet)) assertResult(metadata)(MetadataAccessor.filterMetadata(metadata, ("*" :: Nil).toSet)) } }
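As a quick sketch of the input side of this test: Literal.create pins both the value and the DataType explicitly, which is what the metadata conversion pattern-matches on.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{DoubleType, LongType, NullType}

// Explicitly typed literals, as used in the expression map above.
val typed = Seq(
  Literal.create(10L, LongType),
  Literal.create(1.234, DoubleType),
  Literal.create(null, NullType))
assert(typed.map(_.dataType) == Seq(LongType, DoubleType, NullType))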
Example 34
Source File: CollapseExpandSuite.scala From HANAVora-Extensions with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.CollapseExpandSuite.SqlLikeCatalystSourceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.sources.sql.SqlLikeRelation import org.apache.spark.sql.sources.{BaseRelation, CatalystSource, Table} import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.util.PlanComparisonUtils._ import org.apache.spark.sql.{GlobalSapSQLContext, Row} import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar class CollapseExpandSuite extends FunSuite with MockitoSugar with GlobalSapSQLContext { case object Leaf extends LeafNode { override def output: Seq[Attribute] = Seq.empty } test("Expansion with a single sequence of projections is correctly collapsed") { val expand = Expand( Seq(Seq('a.string, Literal(1))), Seq('a.string, 'gid.int), Leaf) val collapsed = CollapseExpand(expand) assertResult(normalizeExprIds(Project(Seq('a.string, Literal(1) as "gid"), Leaf)))( normalizeExprIds(collapsed)) } test("Expansion with multiple projections is correctly collapsed") { val expand = Expand( Seq( Seq('a.string, Literal(1)), Seq('b.string, Literal(1))), Seq('a.string, 'gid1.int, 'b.string, 'gid2.int), Leaf) val collapsed = CollapseExpand(expand) assertResult( normalizeExprIds( Project(Seq( 'a.string, Literal(1) as "gid1", 'b.string, Literal(1) as "gid2"), Leaf)))(normalizeExprIds(collapsed)) } test("Expand pushdown integration") { val relation = mock[SqlLikeCatalystSourceRelation] when(relation.supportsLogicalPlan(any[Expand])) .thenReturn(true) when(relation.isMultiplePartitionExecution(any[Seq[CatalystSource]])) .thenReturn(true) when(relation.schema) .thenReturn(StructType(StructField("foo", StringType) :: Nil)) when(relation.relationName) .thenReturn("t") when(relation.logicalPlanToRDD(any[LogicalPlan])) .thenReturn(sc.parallelize(Seq(Row("a", 1), Row("b", 1), Row("a", 1)))) sqlc.baseRelationToDataFrame(relation).registerTempTable("t") val dataFrame = sqlc.sql("SELECT COUNT(DISTINCT foo) FROM t") val Seq(Row(ct)) = dataFrame.collect().toSeq assertResult(2)(ct) } } object CollapseExpandSuite { abstract class SqlLikeCatalystSourceRelation extends BaseRelation with Table with SqlLikeRelation with CatalystSource }
Example 35
Source File: HivemallUtils.scala From hivemall-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.hive

import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, Row, UserDefinedFunction}

object HivemallUtils {

  // # of maximum dimensions for feature vectors
  val maxDims = 100000000

  def funcVectorizer(dense: Boolean = false, dims: Int = maxDims): UserDefinedFunction = {
    udf(funcVectorizerImpl(dense, dims))
  }

  private def funcVectorizerImpl(dense: Boolean, dims: Int): Seq[String] => Vector = {
    if (dense) {
      // Dense features
      i: Seq[String] => {
        val features = new Array[Double](dims)
        i.map { ft =>
          val s = ft.split(":").ensuring(_.size == 2)
          features(s(0).toInt) = s(1).toDouble
        }
        Vectors.dense(features)
      }
    } else {
      // Sparse features
      i: Seq[String] => {
        val features = i.map { ft =>
          // val s = ft.split(":").ensuring(_.size == 2)
          val s = ft.split(":")
          (s(0).toInt, s(1).toDouble)
        }
        Vectors.sparse(dims, features)
      }
    }
  }
}
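A standalone sketch of what the sparse branch produces for one input row: "index:value" tokens become (Int, Double) pairs in a sparse vector. The dimension of 4 is made up for the sketch; the real code uses the dims parameter.

import org.apache.spark.mllib.linalg.Vectors

// "index:value" strings -> (index, value) pairs -> sparse vector.
val features = Seq("1:0.5", "3:1.0").map { ft =>
  val s = ft.split(":")
  (s(0).toInt, s(1).toDouble)
}
val vec = Vectors.sparse(4, features)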
Example 36
Source File: HiveClientSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.hive.client import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.conf.HiveConf import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.types.IntegerType class HiveClientSuite extends SparkFunSuite { private val clientBuilder = new HiveClientBuilder private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val testPartitionCount = 5 val storageFormat = CatalogStorageFormat( locationUri = None, inputFormat = None, outputFormat = None, serde = None, compressed = false, properties = Map.empty) val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, false) val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)") val partitions = (1 to testPartitionCount).map { part => CatalogTablePartition(Map("part" -> part.toString), storageFormat) } client.createPartitions( "default", "test", partitions, ignoreIfExists = false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3)))) assert(filteredPartitions.size == testPartitionCount) } }
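A short sketch of the predicate construction used above: the filter handed to getPartitionsByFilter is an ordinary Catalyst predicate, here "part = 3" built from an attribute reference and an integer literal.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.IntegerType

// Partition-pruning predicate: part = 3.
val partCol = AttributeReference("part", IntegerType)()
val predicate = EqualTo(partCol, Literal(3))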
Example 37
Source File: LocalRelation.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) def apply(output1: StructField, output: StructField*): LocalRelation = { new LocalRelation(StructType(output1 +: output).toAttributes) } def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } } case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.") override final def newInstance(): this.type = { LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type] } override protected def stringArgs: Iterator[Any] = { if (data.isEmpty) { Iterator("<empty>", output) } else { Iterator(output) } } override def sameResult(plan: LogicalPlan): Boolean = { plan.canonicalized match { case LocalRelation(otherOutput, otherData) => otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data case _ => false } } override lazy val statistics = Statistics(sizeInBytes = (output.map(n => BigInt(n.dataType.defaultSize))).sum * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) val types = output.map(_.dataType) val rows = data.map { row => val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql } cells.mkString("(", ", ", ")") } "VALUES " + rows.mkString(", ") + " AS " + inlineTableName + output.map(_.name).mkString("(", ", ", ")") } }
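A small usage sketch (names made up) of the toSQL path shown above: every stored cell is rendered through Literal(v, tpe).sql, yielding an inline VALUES table.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.IntegerType

// Build a one-column local relation and render it as inline SQL.
val output = Seq(AttributeReference("a", IntegerType)())
val rel = LocalRelation.fromExternalRows(output, Seq(Row(1), Row(2)))
// rel.toSQL("t") should render as: VALUES (1), (2) AS t(a)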
Example 38
Source File: SubstituteUnresolvedOrdinals.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.types.IntegerType class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] { private def isIntLiteral(e: Expression) = e match { case Literal(_, IntegerType) => true case _ => false } def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other } withOrigin(s.origin)(s.copy(order = newOrders)) case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) => val newGroups = a.groupingExpressions.map { case ordinal @ Literal(index: Int, IntegerType) => withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) case other => other } withOrigin(a.origin)(a.copy(groupingExpressions = newGroups)) } }
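A sketch of the isIntLiteral check in isolation: the Literal(value, dataType) extractor lets the rule recognise integer ordinals such as ORDER BY 1 without evaluating anything.

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.IntegerType

def isIntLiteral(e: Expression): Boolean = e match {
  case Literal(_, IntegerType) => true
  case _ => false
}
assert(isIntLiteral(Literal(2)))      // integer literal: matches
assert(!isIntLiteral(Literal("2")))   // string literal: does not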
Example 39
Source File: RuleExecutorSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } val message = intercept[TreeNodeException[LogicalPlan]] { ToFixedPoint.execute(Literal(100)) }.getMessage assert(message.contains("Max iterations (10) reached for batch fixedPoint")) } }
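A sketch of the rewrite the rule performs on a single expression: IntegerLiteral is a convenience extractor over Literal that matches IntegerType literals and binds the Int value directly.

import org.apache.spark.sql.catalyst.expressions.{IntegerLiteral, Literal}

// One decrement step, outside any RuleExecutor.
val decremented = Literal(10) match {
  case IntegerLiteral(i) if i > 0 => Literal(i - 1)
  case other => other
}
assert(decremented == Literal(9))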
Example 40
Source File: SubstituteUnresolvedOrdinalsSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.SimpleCatalystConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) private lazy val a = testRelation2.output(0) private lazy val b = testRelation2.output(1) test("unresolved ordinal should not be unresolved") { // Expression OrderByOrdinal is unresolved. assert(!UnresolvedOrdinal(0).resolved) } test("order by ordinal") { // Tests order by ordinal, apply single rule. val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan), testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) // Tests order by ordinal, do full analysis checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) // order by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } test("group by ordinal") { // Tests group by ordinal, apply single rule. val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan2), testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) // Tests group by ordinal, do full analysis checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) // group by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } }
Example 41
Source File: ResolveInlineTablesSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types.{LongType, NullType} class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) assert(ResolveInlineTables(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted = ResolveInlineTables.convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) assert(converted.data(0).getLong(0) == 1L) assert(converted.data(1).getLong(0) == 2L) } test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted1 = ResolveInlineTables.convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) val converted2 = ResolveInlineTables.convert(table2) assert(converted2.schema.fields(0).nullable) } }
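A brief sketch of the nullability case at the end of this suite: an untyped NULL cell is modelled as Literal(null, NullType), and its presence is what flips the inferred column to nullable.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.NullType

// A NULL cell in an inline table.
val nullCell = Literal(null, NullType)
assert(nullCell.nullable)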
Example 42
Source File: PartitioningSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} class PartitioningSuite extends SparkFunSuite { test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { val expressions = Seq(Literal(2), Literal(3)) // Consider two HashPartitionings that have the same _set_ of hash expressions but which are // created with different orderings of those expressions: val partitioningA = HashPartitioning(expressions, 100) val partitioningB = HashPartitioning(expressions.reverse, 100) // These partitionings are not considered equal: assert(partitioningA != partitioningB) // However, they both satisfy the same clustered distribution: val distribution = ClusteredDistribution(expressions) assert(partitioningA.satisfies(distribution)) assert(partitioningB.satisfies(distribution)) // These partitionings compute different hashcodes for the same input row: def computeHashCode(partitioning: HashPartitioning): Int = { val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) hashExprProj.apply(InternalRow.empty).hashCode() } assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) // Thus, these partitionings are incompatible: assert(!partitioningA.compatibleWith(partitioningB)) assert(!partitioningB.compatibleWith(partitioningA)) assert(!partitioningA.guarantees(partitioningB)) assert(!partitioningB.guarantees(partitioningA)) // Just to be sure that we haven't cheated by having these methods always return false, // check that identical partitionings are still compatible with and guarantee each other: assert(partitioningA === partitioningA) assert(partitioningA.guarantees(partitioningA)) assert(partitioningA.compatibleWith(partitioningA)) } }
Example 43
Source File: RewriteDistinctAggregatesSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) val nullInt = Literal(null, IntegerType) val nullString = Literal(null, StringType) val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { case Aggregate(_, _, Aggregate(_, _, _: Expand)) => case _ => fail(s"Plan is not rewritten:\n$rewrite") } test("single distinct group") { val input = testRelation .groupBy('a)(countDistinct('e)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), max('b).as('agg2)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with non-partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), CollectSet('b).toAggregateExpression().as('agg2)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with partial aggregates") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with non-partial aggregates") { val input = testRelation .groupBy('a)( countDistinct('b, 'c), countDistinct('d), CollectSet('b).toAggregateExpression()) .analyze checkRewrite(RewriteDistinctAggregates(input)) } }
Example 44
Source File: ComputeCurrentTimeSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils class ComputeCurrentTimeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) } test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) val min = System.currentTimeMillis() * 1000 val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = (System.currentTimeMillis() + 1) * 1000 val lits = new scala.collection.mutable.ArrayBuffer[Long] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Long] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } test("analyzer should replace current_date with literals") { val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) val lits = new scala.collection.mutable.ArrayBuffer[Int] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Int] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } }
Example 45
Source File: AggregateOptimizeSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), FoldablePropagation, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("remove literals in grouping expression") { val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("do not remove all grouping expressions if they are all literals") { val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) comparePlans(optimized, correctAnswer) } test("Remove aliased literals") { val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("remove repetition in grouping expression") { val input = LocalRelation('a.int, 'b.int, 'c.int) val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } }
Example 46
Source File: ExchangeSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { import testImplicits._ test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), plan => ShuffleExchange(SinglePartition, plan), input.map(Row.fromTuple) ) } test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) assert(!mode2.compatibleWith(mode1)) assert(mode2.compatibleWith(mode2)) assert(!mode2.compatibleWith(mode3)) assert(mode3.compatibleWith(mode3)) } test("BroadcastExchange same result") { val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan) val hashMode = HashedRelationBroadcastMode(output) val exchange2 = BroadcastExchangeExec(hashMode, plan) val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) val exchange3 = BroadcastExchangeExec(hashMode2, plan) val exchange4 = ReusedExchangeExec(output, exchange3) assert(exchange1 sameResult exchange1) assert(exchange2 sameResult exchange2) assert(exchange3 sameResult exchange3) assert(exchange4 sameResult exchange4) assert(!exchange1.sameResult(exchange2)) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) assert(exchange4 sameResult exchange3) } test("ShuffleExchange same result") { val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) val part1 = HashPartitioning(output, 1) val exchange1 = ShuffleExchange(part1, plan) val exchange2 = ShuffleExchange(part1, plan) val part2 = HashPartitioning(output, 2) val exchange3 = ShuffleExchange(part2, plan) val part3 = HashPartitioning(output ++ output, 2) val exchange4 = ShuffleExchange(part3, plan) val exchange5 = ReusedExchangeExec(output, exchange4) assert(exchange1 sameResult exchange1) assert(exchange2 sameResult exchange2) assert(exchange3 sameResult exchange3) assert(exchange4 sameResult exchange4) assert(exchange5 sameResult exchange5) assert(exchange1 sameResult exchange2) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) assert(!exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } }
Example 47
Source File: TakeOrderedAndProjectSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) } private def generateRandomInputData(): DataFrame = { val schema = new StructType() .add("a", IntegerType, nullable = false) .add("b", IntegerType, nullable = false) val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) } private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) val limit = 250 val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, SortExec(sortOrder, true, input))), sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter( TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, ProjectExec(Seq(input.output.last), SortExec(sortOrder, true, input)))), sortAnswers = false) } } }
Example 48
Source File: MessageDelimiter.scala From spark-cep with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.streaming.sources

import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, EmptyRow, Literal}
import org.apache.spark.sql.types.StructType

class MessageDelimiter extends MessageToRowConverter with Logging {
  val delimiter = " "

  def toRow(msg: String, schema: StructType): InternalRow = {
    val splitted = msg.split(delimiter).map(Literal(_))
    val casted = splitted.indices.map(i => Cast(splitted(i), schema(i).dataType).eval(EmptyRow))
    InternalRow.fromSeq(casted)
  }

  def toMessage(row: Row): String = row.mkString(delimiter)
}

trait MessageToRowConverter extends Serializable {
  def toRow(message: String, schema: StructType): InternalRow

  def toMessage(row: Row): String
}
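A sketch of the per-field conversion step above, taken out of the class: each raw token is wrapped in a Literal and coerced eagerly through Cast.eval.

import org.apache.spark.sql.catalyst.expressions.{Cast, EmptyRow, Literal}
import org.apache.spark.sql.types.IntegerType

// One token of the delimited message, cast to the schema's column type.
val parsed = Cast(Literal("42"), IntegerType).eval(EmptyRow)
assert(parsed == 42)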
Example 49
Source File: DeltaInvariantCheckerExec.scala From delta with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.delta.schema import org.apache.spark.sql.delta.DeltaErrors import org.apache.spark.sql.delta.schema.Invariants.NotNull import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, Expression, GetStructField, Literal, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.{NullType, StructType} private def buildExtractors(invariant: Invariant): Option[Expression] = { assert(invariant.column.nonEmpty) val topLevelColumn = invariant.column.head val topLevelRefOpt = output.collectFirst { case a: AttributeReference if SchemaUtils.DELTA_COL_RESOLVER(a.name, topLevelColumn) => a } val rejectColumnNotFound = isNullNotOkay(invariant) if (topLevelRefOpt.isEmpty) { if (rejectColumnNotFound) { throw DeltaErrors.notNullInvariantException(invariant) } } if (invariant.column.length == 1) { topLevelRefOpt.map(BindReferences.bindReference[Expression](_, output)) } else { topLevelRefOpt.flatMap { topLevelRef => val boundTopLevel = BindReferences.bindReference[Expression](topLevelRef, output) try { val nested = invariant.column.tail.foldLeft(boundTopLevel) { case (e, fieldName) => e.dataType match { case StructType(fields) => val ordinal = fields.indexWhere(f => SchemaUtils.DELTA_COL_RESOLVER(f.name, fieldName)) if (ordinal == -1) { throw new IndexOutOfBoundsException(s"Not nullable column not found in struct: " + s"${fields.map(_.name).mkString("[", ",", "]")}") } GetStructField(e, ordinal, Some(fieldName)) case _ => throw new UnsupportedOperationException( "Invariants on nested fields other than StructTypes are not supported.") } } Some(nested) } catch { case i: IndexOutOfBoundsException if rejectColumnNotFound => throw InvariantViolationException(invariant, i.getMessage) case _: IndexOutOfBoundsException if !rejectColumnNotFound => None } } } } override protected def doExecute(): RDD[InternalRow] = { if (invariants.isEmpty) return child.execute() val boundRefs = invariants.map { invariant => CheckDeltaInvariant(buildExtractors(invariant).getOrElse(Literal(null, NullType)), invariant) } child.execute().mapPartitionsInternal { rows => val assertions = GenerateUnsafeProjection.generate(boundRefs) rows.map { row => assertions(row) row } } } override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning }
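A sketch (field name "x" is made up) of the nested-column extraction in buildExtractors: nested invariant columns are reached by folding GetStructField over the bound top-level attribute, one ordinal per path segment; when no extractor can be built, the code falls back to Literal(null, NullType).

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Reach field "x" inside struct column "s" by ordinal.
val structCol = AttributeReference("s", StructType(StructField("x", IntegerType) :: Nil))()
val nested = GetStructField(structCol, 0, Some("x"))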
Example 50
Source File: LocalRelation.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) def apply(output1: StructField, output: StructField*): LocalRelation = { new LocalRelation(StructType(output1 +: output).toAttributes) } def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } } case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.") override final def newInstance(): this.type = { LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type] } override protected def stringArgs: Iterator[Any] = { if (data.isEmpty) { Iterator("<empty>", output) } else { Iterator(output) } } override def sameResult(plan: LogicalPlan): Boolean = { plan.canonicalized match { case LocalRelation(otherOutput, otherData) => otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data case _ => false } } override lazy val statistics = Statistics(sizeInBytes = (output.map(n => BigInt(n.dataType.defaultSize))).sum * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) val types = output.map(_.dataType) val rows = data.map { row => val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql } cells.mkString("(", ", ", ")") } "VALUES " + rows.mkString(", ") + " AS " + inlineTableName + output.map(_.name).mkString("(", ", ", ")") } }
Example 51
Source File: SubstituteUnresolvedOrdinals.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.types.IntegerType class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] { private def isIntLiteral(e: Expression) = e match { case Literal(_, IntegerType) => true case _ => false } def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other } withOrigin(s.origin)(s.copy(order = newOrders)) case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) => val newGroups = a.groupingExpressions.map { case ordinal @ Literal(index: Int, IntegerType) => withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) case other => other } withOrigin(a.origin)(a.copy(groupingExpressions = newGroups)) } }
Example 52
Source File: RuleExecutorSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } val message = intercept[TreeNodeException[LogicalPlan]] { ToFixedPoint.execute(Literal(100)) }.getMessage assert(message.contains("Max iterations (10) reached for batch fixedPoint")) } }
Example 53
Source File: SubstituteUnresolvedOrdinalsSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.SimpleCatalystConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) private lazy val a = testRelation2.output(0) private lazy val b = testRelation2.output(1) test("unresolved ordinal should not be unresolved") { // Expression OrderByOrdinal is unresolved. assert(!UnresolvedOrdinal(0).resolved) } test("order by ordinal") { // Tests order by ordinal, apply single rule. val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan), testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) // Tests order by ordinal, do full analysis checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) // order by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } test("group by ordinal") { // Tests group by ordinal, apply single rule. val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan2), testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) // Tests group by ordinal, do full analysis checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) // group by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } }
Example 54
Source File: ResolveInlineTablesSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types.{LongType, NullType} class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { ResolveInlineTables.validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { ResolveInlineTables.validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) assert(ResolveInlineTables(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted = ResolveInlineTables.convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) assert(converted.data(0).getLong(0) == 1L) assert(converted.data(1).getLong(0) == 2L) } test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted1 = ResolveInlineTables.convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) val converted2 = ResolveInlineTables.convert(table2) assert(converted2.schema.fields(0).nullable) } }
Example 55
Source File: PartitioningSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} class PartitioningSuite extends SparkFunSuite { test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { val expressions = Seq(Literal(2), Literal(3)) // Consider two HashPartitionings that have the same _set_ of hash expressions but which are // created with different orderings of those expressions: val partitioningA = HashPartitioning(expressions, 100) val partitioningB = HashPartitioning(expressions.reverse, 100) // These partitionings are not considered equal: assert(partitioningA != partitioningB) // However, they both satisfy the same clustered distribution: val distribution = ClusteredDistribution(expressions) assert(partitioningA.satisfies(distribution)) assert(partitioningB.satisfies(distribution)) // These partitionings compute different hashcodes for the same input row: def computeHashCode(partitioning: HashPartitioning): Int = { val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) hashExprProj.apply(InternalRow.empty).hashCode() } assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) // Thus, these partitionings are incompatible: assert(!partitioningA.compatibleWith(partitioningB)) assert(!partitioningB.compatibleWith(partitioningA)) assert(!partitioningA.guarantees(partitioningB)) assert(!partitioningB.guarantees(partitioningA)) // Just to be sure that we haven't cheated by having these methods always return false, // check that identical partitionings are still compatible with and guarantee each other: assert(partitioningA === partitioningA) assert(partitioningA.guarantees(partitioningA)) assert(partitioningA.compatibleWith(partitioningA)) } }
Example 56
Source File: RewriteDistinctAggregatesSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) val nullInt = Literal(null, IntegerType) val nullString = Literal(null, StringType) val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { case Aggregate(_, _, Aggregate(_, _, _: Expand)) => case _ => fail(s"Plan is not rewritten:\n$rewrite") } test("single distinct group") { val input = testRelation .groupBy('a)(countDistinct('e)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), max('b).as('agg2)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with non-partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), CollectSet('b).toAggregateExpression().as('agg2)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with partial aggregates") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with non-partial aggregates") { val input = testRelation .groupBy('a)( countDistinct('b, 'c), countDistinct('d), CollectSet('b).toAggregateExpression()) .analyze checkRewrite(RewriteDistinctAggregates(input)) } }
Example 57
Source File: ComputeCurrentTimeSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils class ComputeCurrentTimeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) } test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) val min = System.currentTimeMillis() * 1000 val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = (System.currentTimeMillis() + 1) * 1000 val lits = new scala.collection.mutable.ArrayBuffer[Long] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Long] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } test("analyzer should replace current_date with literals") { val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) val lits = new scala.collection.mutable.ArrayBuffer[Int] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Int] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } }
Example 58
Source File: AggregateOptimizeSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), FoldablePropagation, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("remove literals in grouping expression") { val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("do not remove all grouping expressions if they are all literals") { val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) comparePlans(optimized, correctAnswer) } test("Remove aliased literals") { val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("remove repetition in grouping expression") { val input = LocalRelation('a.int, 'b.int, 'c.int) val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } }
Example 59
Source File: ExchangeSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { import testImplicits._ test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), plan => ShuffleExchange(SinglePartition, plan), input.map(Row.fromTuple) ) } test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) assert(!mode2.compatibleWith(mode1)) assert(mode2.compatibleWith(mode2)) assert(!mode2.compatibleWith(mode3)) assert(mode3.compatibleWith(mode3)) } test("BroadcastExchange same result") { val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan) val hashMode = HashedRelationBroadcastMode(output) val exchange2 = BroadcastExchangeExec(hashMode, plan) val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) val exchange3 = BroadcastExchangeExec(hashMode2, plan) val exchange4 = ReusedExchangeExec(output, exchange3, sparkContext.sparkUser) assert(exchange1 sameResult exchange1) assert(exchange2 sameResult exchange2) assert(exchange3 sameResult exchange3) assert(exchange4 sameResult exchange4) assert(!exchange1.sameResult(exchange2)) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) assert(exchange4 sameResult exchange3) } test("ShuffleExchange same result") { val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) val part1 = HashPartitioning(output, 1) val exchange1 = ShuffleExchange(part1, plan) val exchange2 = ShuffleExchange(part1, plan) val part2 = HashPartitioning(output, 2) val exchange3 = ShuffleExchange(part2, plan) val part3 = HashPartitioning(output ++ output, 2) val exchange4 = ShuffleExchange(part3, plan) val exchange5 = ReusedExchangeExec(output, exchange4, sparkContext.sparkUser) assert(exchange1 sameResult exchange1) assert(exchange2 sameResult exchange2) assert(exchange3 sameResult exchange3) assert(exchange4 sameResult exchange4) assert(exchange5 sameResult exchange5) assert(exchange1 sameResult exchange2) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) assert(!exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } }
Example 60
Source File: TakeOrderedAndProjectSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) } private def generateRandomInputData(): DataFrame = { val schema = new StructType() .add("a", IntegerType, nullable = false) .add("b", IntegerType, nullable = false) val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) } private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) val limit = 250 val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, SortExec(sortOrder, true, input))), sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter( TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, ProjectExec(Seq(input.output.last), SortExec(sortOrder, true, input)))), sortAnswers = false) } } }
Example 61
Source File: RuleExecutorSuite.scala From iolap with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

class RuleExecutorSuite extends SparkFunSuite {
  object DecrementLiterals extends Rule[Expression] {
    def apply(e: Expression): Expression = e transform {
      case IntegerLiteral(i) if i > 0 => Literal(i - 1)
    }
  }

  test("only once") {
    object ApplyOnce extends RuleExecutor[Expression] {
      val batches = Batch("once", Once, DecrementLiterals) :: Nil
    }

    assert(ApplyOnce.execute(Literal(10)) === Literal(9))
  }

  test("to fixed point") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil
    }

    assert(ToFixedPoint.execute(Literal(10)) === Literal(0))
  }

  test("to maxIterations") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
    }

    assert(ToFixedPoint.execute(Literal(100)) === Literal(90))
  }
}
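Note: the IntegerLiteral extractor used above matches any Literal whose value is an Int. As a further illustration, the same extractor can drive a small constant-folding rule run to a fixed point. This is a sketch, not code from any listed project; the object names are made up, and it assumes the two-argument Add constructor of the Spark versions these examples target.

import org.apache.spark.sql.catalyst.expressions.{Add, Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

// Folds an Add of two integer literals into a single Literal.
object FoldAddOfLiterals extends Rule[Expression] {
  def apply(e: Expression): Expression = e transform {
    case Add(IntegerLiteral(a), IntegerLiteral(b)) => Literal(a + b)
  }
}

object FoldUntilStable extends RuleExecutor[Expression] {
  val batches = Batch("fold", FixedPoint(100), FoldAddOfLiterals) :: Nil
}

object FoldExample {
  def main(args: Array[String]): Unit = {
    val folded = FoldUntilStable.execute(Add(Add(Literal(1), Literal(2)), Literal(3)))
    println(folded) // prints the folded literal: 6
  }
}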
Example 62
Source File: TypedExpressionEncoder.scala From frameless with Apache License 2.0 | 5 votes |
package frameless import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If, Literal} import org.apache.spark.sql.types.StructType object TypedExpressionEncoder { def targetStructType[A](encoder: TypedEncoder[A]): StructType = { encoder.catalystRepr match { case x: StructType => if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true))) else x case dt => new StructType().add("_1", dt, nullable = encoder.nullable) } } def apply[T: TypedEncoder]: ExpressionEncoder[T] = { val encoder = TypedEncoder[T] val schema = targetStructType(encoder) val in = BoundReference(0, encoder.jvmRepr, encoder.nullable) val (out, toRowExpressions) = encoder.toCatalyst(in) match { case If(_, _, x: CreateNamedStruct) => val out = BoundReference(0, encoder.catalystRepr, encoder.nullable) (out, x.flatten) case other => val out = GetColumnByOrdinal(0, encoder.catalystRepr) (out, CreateNamedStruct(Literal("_1") :: other :: Nil).flatten) } new ExpressionEncoder[T]( schema = schema, flat = false, serializer = toRowExpressions, deserializer = encoder.fromCatalyst(out), clsTag = encoder.classTag ) } }
Example 63
Source File: package.scala From frameless with Apache License 2.0 | 5 votes |
package frameless

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Literal

package object functions extends Udf with UnaryFunctions {
  object aggregate extends AggregateFunctions
  object nonAggregate extends NonAggregateFunctions

  def lit[A: TypedEncoder, T](value: A): TypedColumn[T, A] = {
    val encoder = TypedEncoder[A]

    if (ScalaReflection.isNativeType(encoder.jvmRepr) && encoder.catalystRepr == encoder.jvmRepr) {
      val expr = Literal(value, encoder.catalystRepr)
      new TypedColumn(expr)
    } else {
      val expr = FramelessLit(value, encoder)
      new TypedColumn(expr)
    }
  }
}
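Note: frameless' lit only builds a raw Catalyst Literal when the value's JVM representation already matches its Catalyst representation; everything else goes through FramelessLit and the encoder. For comparison, vanilla Spark's functions.lit also bottoms out in a Catalyst Literal for simple values. The sketch below is illustrative only (not frameless or Spark source), and the object name is made up.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.functions.lit

object LitComparison {
  def main(args: Array[String]): Unit = {
    val column = lit(42)
    // Inspect the Catalyst expression behind the Column.
    column.expr match {
      case l: Literal => println(s"Literal of ${l.dataType}") // IntegerType
      case other      => println(s"not a literal: $other")
    }
  }
}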
Example 64
package frameless.functions import frameless.TypedEncoder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NonSQLExpression} import org.apache.spark.sql.types.DataType case class FramelessLit[A](obj: A, encoder: TypedEncoder[A]) extends Expression with NonSQLExpression { override def nullable: Boolean = encoder.nullable override def toString: String = s"FramelessLit($obj)" def eval(input: InternalRow): Any = { val ctx = new CodegenContext() val eval = genCode(ctx) val codeBody = s""" public scala.Function1<InternalRow, Object> generate(Object[] references) { return new FramelessLitEvalImpl(references); } class FramelessLitEvalImpl extends scala.runtime.AbstractFunction1<InternalRow, Object> { private final Object[] references; ${ctx.declareMutableStates()} ${ctx.declareAddedFunctions()} public FramelessLitEvalImpl(Object[] references) { this.references = references; ${ctx.initMutableStates()} } public java.lang.Object apply(java.lang.Object z) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) z; ${eval.code} return ${eval.isNull} ? ((Object)null) : ((Object)${eval.value}); } } """ val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) val (clazz, _) = CodeGenerator.compile(code) val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] codegen(input) } def dataType: DataType = encoder.catalystRepr def children: Seq[Expression] = Nil override def genCode(ctx: CodegenContext): ExprCode = { encoder.toCatalyst(new Literal(obj, encoder.jvmRepr)).genCode(ctx) } protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ??? }
Example 65
Source File: RuleExecutorSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) } }
Example 66
Source File: AggregateOptimizeSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      ReplaceDistinctWithAggregate,
      RemoveLiteralFromGroupExpressions) :: Nil
  }

  // Replace Distinct with an Aggregate.
  test("replace distinct with aggregate") {
    val input = LocalRelation('a.int, 'b.int)

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)
  }

  // Remove literals from grouping expressions.
  test("remove literals in grouping expression") {
    val input = LocalRelation('a.int, 'b.int)

    val query = input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(query)

    val correctAnswer = input.groupBy('a)(sum('b))

    comparePlans(optimized, correctAnswer)
  }
}
Example 67
Source File: ExtraStrategiesSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package test.org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String

// A fast operator that always produces the single row ["so fast"].
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    sparkContext.parallelize(Seq(row))
  }

  // Nil is the empty List.
  override def children: Seq[SparkPlan] = Nil
}

// A test strategy that plans Project(a) using the fast operator above.
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      // :: prepends an element to a List, creating a new list.
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

// Suite for extra planning strategies.
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // Insert an extra strategy.
  test("insert an extraStrategy") {
    try {
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      sqlContext.experimental.extraStrategies = Nil
    }
  }
}
Example 68
Source File: StatefulApproxQuantile.scala From deequ with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, Literal} import org.apache.spark.sql.types._ private[sql] case class StatefulApproxQuantile( child: Expression, accuracyExpression: Expression, override val mutableAggBufferOffset: Int, override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes { def this(child: Expression, accuracyExpression: Expression) = { this(child, accuracyExpression, 0, 0) } def this(child: Expression) = { this(child, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) } // Mark as lazy so that accuracyExpression is not evaluated during tree transformation. private lazy val accuracy: Double = accuracyExpression.eval().asInstanceOf[Double] override def inputTypes: Seq[AbstractDataType] = { Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType) } override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() if (defaultCheck.isFailure) { defaultCheck } else if (!accuracyExpression.foldable) { TypeCheckFailure(s"The accuracy provided must be a constant literal") } else if (accuracy <= 0) { TypeCheckFailure( s"The accuracy provided must be a positive integer literal (current value = $accuracy)") } else { TypeCheckSuccess } } override def createAggregationBuffer(): PercentileDigest = { val relativeError = 1.0D / accuracy new PercentileDigest(relativeError) } override def update(buffer: PercentileDigest, inputRow: InternalRow): PercentileDigest = { val value = child.eval(inputRow) // Ignore empty rows, for example: percentile_approx(null) if (value != null) { buffer.add(value.asInstanceOf[Double]) } buffer } override def merge(buffer: PercentileDigest, other: PercentileDigest): PercentileDigest = { buffer.merge(other) buffer } override def eval(buffer: PercentileDigest): Any = { // instead of evaluating the PercentileDigest quantile summary here, // serialize the digest and return it as byte array serialize(buffer) } override def withNewMutableAggBufferOffset(newOffset: Int): StatefulApproxQuantile = copy(mutableAggBufferOffset = newOffset) override def withNewInputAggBufferOffset(newOffset: Int): StatefulApproxQuantile = copy(inputAggBufferOffset = newOffset) override def children: Seq[Expression] = Seq(child, accuracyExpression) // Returns null for empty inputs override def nullable: Boolean = true override def dataType: DataType = BinaryType override def prettyName: String = "percentile_approx" override def serialize(digest: PercentileDigest): Array[Byte] = { ApproximatePercentile.serializer.serialize(digest) } override def deserialize(bytes: Array[Byte]): PercentileDigest = { ApproximatePercentile.serializer.deserialize(bytes) } }
Example 69
Source File: DeequFunctions.scala From deequ with Apache License 2.0 | 5 votes |
package org.apache.spark.sql

import com.amazon.deequ.analyzers.KLLSketch
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, StatefulApproxQuantile, StatefulHyperloglogPlus}
import org.apache.spark.sql.catalyst.expressions.Literal

// Only an excerpt of DeequFunctions is shown in this listing; the enclosing
// object declaration is restored here so the snippet parses.
object DeequFunctions {

  def stateful_datatype(column: Column): Column = {
    val statefulDataType = new StatefulDataType()
    statefulDataType(column)
  }

  def stateful_kll(
      column: Column,
      sketchSize: Int,
      shrinkingFactor: Double): Column = {
    val statefulKLL = new StatefulKLLSketch(sketchSize, shrinkingFactor)
    statefulKLL(column)
  }
}
Example 70
Source File: ValueInterval.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

// Only an excerpt of this file is shown in the listing; the enclosing
// ValueInterval object declaration is restored here so the snippet parses.
object ValueInterval {

  def intersect(r1: ValueInterval, r2: ValueInterval, dt: DataType): (Option[Any], Option[Any]) = {
    (r1, r2) match {
      case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) =>
        // binary/string types don't support intersecting.
        (None, None)
      case (n1: NumericValueInterval, n2: NumericValueInterval) =>
        // Choose the larger of the two min values and the smaller of the two max values.
        val newMin = if (n1.min <= n2.min) n2.min else n1.min
        val newMax = if (n1.max <= n2.max) n1.max else n2.max
        (Some(EstimationUtils.fromDouble(newMin, dt)), Some(EstimationUtils.fromDouble(newMax, dt)))
    }
  }
}
Example 71
Source File: LocalRelation.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) def apply(output1: StructField, output: StructField*): LocalRelation = { new LocalRelation(StructType(output1 +: output).toAttributes) } def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } } case class LocalRelation( output: Seq[Attribute], data: Seq[InternalRow] = Nil, // Indicates whether this relation has data from a streaming source. override val isStreaming: Boolean = false) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.") override final def newInstance(): this.type = { LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type] } override protected def stringArgs: Iterator[Any] = { if (data.isEmpty) { Iterator("<empty>", output) } else { Iterator(output) } } override def computeStats(): Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) val types = output.map(_.dataType) val rows = data.map { row => val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql } cells.mkString("(", ", ", ")") } "VALUES " + rows.mkString(", ") + " AS " + inlineTableName + output.map(_.name).mkString("(", ", ", ")") } }
Example 72
Source File: SubstituteUnresolvedOrdinals.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
}
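Note: the rule above recognizes ordinals purely by pattern matching on Literal(index: Int, IntegerType). A minimal stand-alone sketch of that pattern follows; the names are illustrative and it is not Spark source.

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.IntegerType

object OrdinalPattern {
  // Because Literal is a case class, an ordinal such as ORDER BY 2 can be
  // recognized by matching on both its value and its data type.
  def asOrdinal(e: Expression): Option[Int] = e match {
    case Literal(index: Int, IntegerType) => Some(index)
    case _                                => None
  }

  def main(args: Array[String]): Unit = {
    println(asOrdinal(Literal(2)))                 // Some(2)
    println(asOrdinal(Literal(2L)))                // None (LongType, not IntegerType)
    println(asOrdinal(Literal(null, IntegerType))) // None (null value)
  }
}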
Example 73
Source File: QueryPlanSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)
    CurrentOrigin.reset()

    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)
    }

    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))
  }
}
Example 74
Source File: RuleExecutorSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } val message = intercept[TreeNodeException[LogicalPlan]] { ToFixedPoint.execute(Literal(100)) }.getMessage assert(message.contains("Max iterations (10) reached for batch fixedPoint")) } test("structural integrity checker") { object WithSIChecker extends RuleExecutor[Expression] { override protected def isPlanIntegral(expr: Expression): Boolean = expr match { case IntegerLiteral(_) => true case _ => false } val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(WithSIChecker.execute(Literal(10)) === Literal(9)) val message = intercept[TreeNodeException[LogicalPlan]] { WithSIChecker.execute(Literal(10.1)) }.getMessage assert(message.contains("the structural integrity of the plan is broken")) } }
Example 75
Source File: SubstituteUnresolvedOrdinalsSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.internal.SQLConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val a = testRelation2.output(0) private lazy val b = testRelation2.output(1) test("unresolved ordinal should not be unresolved") { // Expression OrderByOrdinal is unresolved. assert(!UnresolvedOrdinal(0).resolved) } test("order by ordinal") { // Tests order by ordinal, apply single rule. val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan), testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) // Tests order by ordinal, do full analysis checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) // order by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.ORDER_BY_ORDINAL -> false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } test("group by ordinal") { // Tests group by ordinal, apply single rule. val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) comparePlans( new SubstituteUnresolvedOrdinals(conf).apply(plan2), testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) // Tests group by ordinal, do full analysis checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) // group by ordinal can be turned off by config comparePlans( new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.GROUP_BY_ORDINAL -> false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } }
Example 76
Source File: ResolveInlineTablesSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) assert(ResolveInlineTables(conf)(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted = ResolveInlineTables(conf).convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) assert(converted.data(0).getLong(0) == 1L) assert(converted.data(1).getLong(0) == 2L) } test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) val withTimeZone = ResolveTimeZone(conf).apply(table) val LocalRelation(output, data, _) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] assert(output.map(_.dataType) == Seq(TimestampType)) assert(data.size == 1) assert(data.head.getLong(0) == correct) } test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) val converted1 = ResolveInlineTables(conf).convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) val converted2 = ResolveInlineTables(conf).convert(table2) assert(converted2.schema.fields(0).nullable) } }
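Note: the inline-table checks above accept literals (and casts of literals) because such expressions are foldable and can be evaluated without an input row, while Rand, aggregates and unresolved attributes are rejected. The sketch below only illustrates that property; it is not Spark source, the object name is made up, and it assumes the two-argument Add constructor of the Spark versions these examples target.

import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

object Foldability {
  def main(args: Array[String]): Unit = {
    val lit = Literal(1)
    // An expression built only from literals is also foldable.
    val sum = Add(Literal(1), Literal(2))
    println(lit.foldable) // true
    println(sum.foldable) // true
    println(sum.eval())   // 3, evaluated with no input row
  }
}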
Example 77
Source File: OptimizerStructuralIntegrityCheckerSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { object OptimizeRuleBreakSI extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Project(projectList, child) => val newAttr = UnresolvedAttribute("unresolvedAttr") Project(projectList ++ Seq(newAttr), child) } } object Optimize extends Optimizer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, new SQLConf())) { val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI) override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches } test("check for invalid plan after execution of rule") { val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze assert(analyzed.resolved) val message = intercept[TreeNodeException[LogicalPlan]] { Optimize.execute(analyzed) }.getMessage val ruleName = OptimizeRuleBreakSI.ruleName assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI")) assert(message.contains("the structural integrity of the plan is broken")) } }
Example 78
Source File: RewriteDistinctAggregatesSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) val nullInt = Literal(null, IntegerType) val nullString = Literal(null, StringType) val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { case Aggregate(_, _, Aggregate(_, _, _: Expand)) => case _ => fail(s"Plan is not rewritten:\n$rewrite") } test("single distinct group") { val input = testRelation .groupBy('a)(countDistinct('e)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("single distinct group with partial aggregates") { val input = testRelation .groupBy('a, 'd)( countDistinct('e, 'c).as('agg1), max('b).as('agg2)) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) } test("multiple distinct groups") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with partial aggregates") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with non-partial aggregates") { val input = testRelation .groupBy('a)( countDistinct('b, 'c), countDistinct('d), CollectSet('b).toAggregateExpression()) .analyze checkRewrite(RewriteDistinctAggregates(input)) } }
Example 79
Source File: ComputeCurrentTimeSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils class ComputeCurrentTimeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) } test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) val min = System.currentTimeMillis() * 1000 val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = (System.currentTimeMillis() + 1) * 1000 val lits = new scala.collection.mutable.ArrayBuffer[Long] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Long] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } test("analyzer should replace current_date with literals") { val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) val lits = new scala.collection.mutable.ArrayBuffer[Int] plan.transformAllExpressions { case e: Literal => lits += e.value.asInstanceOf[Int] e } assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } }
Example 80
Source File: AggregateOptimizeSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} class AggregateOptimizeSuite extends PlanTest { override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), FoldablePropagation, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("remove literals in grouping expression") { val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("do not remove all grouping expressions if they are all literals") { val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) comparePlans(optimized, correctAnswer) } test("Remove aliased literals") { val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } test("remove repetition in grouping expression") { val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } }
Example 81
Source File: TakeOrderedAndProjectSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) } private def generateRandomInputData(): DataFrame = { val schema = new StructType() .add("a", IntegerType, nullable = false) .add("b", IntegerType, nullable = false) val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) } private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) val limit = 250 val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, SortExec(sortOrder, true, input))), sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { withClue(s"seed = $seed") { checkThatPlansAgree( generateRandomInputData(), input => noOpFilter( TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, ProjectExec(Seq(input.output.last), SortExec(sortOrder, true, input)))), sortAnswers = false) } } }
Example 82
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.test.SharedSQLContext class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height // histogram usually contains hundreds of buckets. So we need to test // ApproxCountDistinctForIntervals with large number of endpoints // (the number of endpoints == the number of buckets + 1). test("test ApproxCountDistinctForIntervals with large number of endpoints") { val table = "approx_count_distinct_for_intervals_tbl" withTable(table) { (1 to 100000).toDF("col").createOrReplaceTempView(table) // percentiles of 0, 0.001, 0.002 ... 0.999, 1 val endpoints = (0 to 1000).map(_ * 100000 / 1000) // Since approx_count_distinct_for_intervals is not a public function, here we do // the computation by constructing logical plan. val relation = spark.table(table).logicalPlan val attr = relation.output.find(_.name == "col").get val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_)))) val aggExpr = aggFunc.toAggregateExpression() val namedExpr = Alias(aggExpr, aggExpr.toString)() val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation)) .executedPlan.executeTake(1).head val ndvArray = ndvsRow.getArray(0).toLongArray() assert(endpoints.length == ndvArray.length + 1) // Each bucket has 100 distinct values. val expectedNdv = 100 for (i <- ndvArray.indices) { val ndv = ndvArray(i) val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.") } } } }
Example 83
Source File: GroupOr.scala From mimir with Apache License 2.0 | 5 votes |
package mimir.exec.spark.udf import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.{ DataType, BooleanType } import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, Or } case class GroupOr(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate { override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil override def nullable: Boolean = false // Return data type. override def dataType: DataType = BooleanType override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForOrderingExpr(child.dataType, "function group_or") private lazy val group_or = AttributeReference("group_or", BooleanType)() override lazy val aggBufferAttributes: Seq[AttributeReference] = group_or :: Nil override lazy val initialValues: Seq[Literal] = Seq( Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq( Or(group_or, child) ) override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = { Seq( Or(group_or.left, group_or.right) ) } override lazy val evaluateExpression: AttributeReference = group_or }
Example 84
Source File: GroupAnd.scala From mimir with Apache License 2.0 | 5 votes |
package mimir.exec.spark.udf import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.{ DataType, BooleanType } import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, And } case class GroupAnd(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate { override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil override def nullable: Boolean = false // Return data type. override def dataType: DataType = BooleanType override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForOrderingExpr(child.dataType, "function group_and") private lazy val group_and = AttributeReference("group_and", BooleanType)() override lazy val aggBufferAttributes: Seq[AttributeReference] = group_and :: Nil override lazy val initialValues: Seq[Literal] = Seq( Literal.create(true, BooleanType) ) override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq( And(group_and, child) ) override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = { Seq( And(group_and.left, group_and.right) ) } override lazy val evaluateExpression: AttributeReference = group_and }
Example 85
Source File: GroupBitwiseAnd.scala From mimir with Apache License 2.0 | 5 votes |
package mimir.exec.spark.udf import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.{ DataType, LongType } import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, BitwiseAnd } case class GroupBitwiseAnd(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate { override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForOrderingExpr(child.dataType, "function group_bitwise_and") private lazy val group_bitwise_and = AttributeReference("group_bitwise_and", LongType)() override lazy val aggBufferAttributes: Seq[AttributeReference] = group_bitwise_and :: Nil override lazy val initialValues: Seq[Literal] = Seq( Literal.create(0xffffffffffffffffl, LongType) ) override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq( BitwiseAnd(group_bitwise_and, child) ) override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = { Seq( BitwiseAnd(group_bitwise_and.left, group_bitwise_and.right) ) } override lazy val evaluateExpression: AttributeReference = group_bitwise_and }
Example 86
Source File: GroupBitwiseOr.scala From mimir with Apache License 2.0 | 5 votes |
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, LongType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, BitwiseOr }

case class GroupBitwiseOr(child: org.apache.spark.sql.catalyst.expressions.Expression)
  extends DeclarativeAggregate {

  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil

  override def nullable: Boolean = false

  // Return data type.
  override def dataType: DataType = LongType

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function group_bitwise_or")

  private lazy val group_bitwise_or = AttributeReference("group_bitwise_or", LongType)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_bitwise_or :: Nil

  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(0, LongType)
  )

  override lazy val updateExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    BitwiseOr(group_bitwise_or, child)
  )

  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    BitwiseOr(group_bitwise_or.left, group_bitwise_or.right)
  )

  override lazy val evaluateExpression: AttributeReference = group_bitwise_or
}
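Note: the four mimir aggregates above all seed their aggregation buffers with literals built via Literal.create. For plain JVM values whose Catalyst representation is the value itself (Boolean, Long, ...), this is equivalent to the ordinary Literal constructor, as the small sketch below shows. It is illustrative only and the object name is made up.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{BooleanType, LongType}

object InitialValues {
  def main(args: Array[String]): Unit = {
    // All bits set, as used to seed a bitwise-AND buffer.
    val allBitsSet = Literal.create(0xffffffffffffffffL, LongType)
    // False, as used to seed a group_or buffer.
    val noneSeen = Literal.create(false, BooleanType)

    println(allBitsSet == Literal(-1L))  // true
    println(noneSeen == Literal(false))  // true
  }
}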
Example 87
Source File: StarryLocalRelation.scala From starry with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.logical

import org.apache.spark.sql.catalyst.{InternalRow, analysis}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}

// The class declaration was cut off in this listing; it is restored below from
// the members that follow (a LocalRelation-like leaf node with output, data and
// an isStreaming flag).
case class StarryLocalRelation(
    output: Seq[Attribute],
    data: Seq[InternalRow] = Nil,
    override val isStreaming: Boolean = false)
  extends LeafNode with analysis.MultiInstanceRelation {

  override final def newInstance(): this.type = {
    LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type]
  }

  override protected def stringArgs: Iterator[Any] = {
    if (data.isEmpty) {
      Iterator("<empty>", output)
    } else {
      Iterator(output)
    }
  }

  override def computeStats(): Statistics =
    Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)

  def toSQL(inlineTableName: String): String = {
    require(data.nonEmpty)
    val types = output.map(_.dataType)
    val rows = data.map { row =>
      val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql }
      cells.mkString("(", ", ", ")")
    }
    "VALUES " + rows.mkString(", ") + " AS " + inlineTableName +
      output.map(_.name).mkString("(", ", ", ")")
  }
}
Example 88
Source File: PullTupleCreator.scala From ingraph with Eclipse Public License 1.0 | 5 votes |
package ingraph.ire.adapters.tuplecreators import ingraph.ire.inputs.InputTransaction import ingraph.ire.{IdParser, Indexer, IngraphEdge} import ingraph.model.fplan.{GetEdges, GetVertices} import org.apache.spark.sql.catalyst.expressions.Literal class PullTupleCreator(vertexOps: Seq[GetVertices], edgeOps: Seq[GetEdges], indexer: Indexer, inputTransaction: InputTransaction, idParser: IdParser ) { for (op <- vertexOps) { val v = op.nnode.v val opLabels = v.labels.vertexLabels val vertices = v.properties.get(TupleConstants.ID_KEY) match { case None => indexer.verticesByLabel(opLabels.head).filter(v => opLabels.subsetOf(v.labels)) case Some(Literal(id, _)) => val label = v.labels.vertexLabels.head val vertex = indexer.vertexByIdLabel( id.asInstanceOf[Long], label ).getOrElse(throw new IllegalStateException(s"Vertex not found with label ${label} and id ${id}")) assert(opLabels.subsetOf(vertex.labels), "Wrong labels on direct delete") Seq(vertex) } for (vertex <- vertices) { val tuple = VertexTransformer(vertex, op, idParser) inputTransaction.add(v.name, tuple) } } for (operator <- edgeOps) { val sourceLabels = operator.nnode.src.labels.vertexLabels val targetLabels = operator.nnode.trg.labels.vertexLabels val labels = operator.nnode.edge.labels.edgeLabels val edges: Iterable[IngraphEdge] = operator.nnode.edge.properties.get(TupleConstants.ID_KEY) match { case None => val src = operator.src.properties.get(TupleConstants.ID_KEY) val dst = operator.trg.properties.get(TupleConstants.ID_KEY) val unfiltered = (src, dst) match { case (Some(Literal(srcId,_)), _) => indexer.vertexLookup(srcId.asInstanceOf[Long]).edgesOut.filter(e => labels.contains(e.`type`)) case (_, Some(Literal(dstId, _))) => indexer.vertexLookup(dstId.asInstanceOf[Long]).edgesIn.filter(e => labels.contains(e.`type`)) case _ => labels.flatMap(label => indexer.edgesByType(label)) } unfiltered.filter(e => sourceLabels.subsetOf(e.sourceVertex.labels) && targetLabels.subsetOf(e.targetVertex.labels)) case Some(Literal(id, _)) => val edge = indexer.edgeById(id.asInstanceOf[Long]).get assert(sourceLabels.subsetOf(edge.sourceVertex.labels) && targetLabels.subsetOf(edge.targetVertex.labels), "Wrong vertex labels on direct delete") assert(edge.`type` == labels.head, "Wrong edge type on direct delete") Seq(edge) } val operatorString = operator.toString() for (edge <- edges) { val tuple = EdgeTransformer(edge, operator, idParser) inputTransaction.add(operatorString, tuple) if (!operator.nnode.directed) { val rTuple = EdgeTransformer( edge.copy(sourceVertex = edge.targetVertex, targetVertex = edge.sourceVertex), operator, idParser) inputTransaction.add(operatorString, rTuple) } } } }
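Note: the adapter above repeatedly extracts constant ids from property maps by matching Some(Literal(id, _)) and casting the value to Long. A stand-alone sketch of that lookup pattern follows; the names are illustrative, not ingraph source.

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

object PropertyLookup {
  // Pull a constant id out of a property map whose values are Catalyst expressions.
  def constantId(properties: Map[String, Expression]): Option[Long] =
    properties.get("id") match {
      case Some(Literal(id: Long, _)) => Some(id)
      case _                          => None
    }

  def main(args: Array[String]): Unit = {
    println(constantId(Map("id" -> Literal(42L)))) // Some(42)
    println(constantId(Map.empty))                 // None
  }
}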
Example 89
Source File: monotonicaggregates.scala From BigDatalog with Apache License 2.0 | 5 votes |
package edu.ucla.cs.wis.bigdatalog.spark.execution.aggregates

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, Greatest, Least, Literal, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}

abstract class MonotonicAggregateFunction extends DeclarativeAggregate with Serializable {}

case class MMax(child: Expression) extends MonotonicAggregateFunction {
  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = child.dataType

  // Expected input data type.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function mmax")

  private lazy val mmax = AttributeReference("mmax", child.dataType)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] = mmax :: Nil

  // NOTE: the original listing is truncated from here to the class below (it
  // spliced the start of MMax onto the tail of an analogous MMin aggregate).
  // The remaining members of MMax are restored by analogy with Spark's Max
  // aggregate; the project's MMin is the Least-based counterpart.
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(null, child.dataType)
  )

  override lazy val updateExpressions: Seq[Expression] = Seq(
    Greatest(Seq(mmax, child))
  )

  override lazy val mergeExpressions: Seq[Expression] = Seq(
    Greatest(Seq(mmax.left, mmax.right))
  )

  override lazy val evaluateExpression: AttributeReference = mmax
}

case class MonotonicAggregateExpression(aggregateFunction: MonotonicAggregateFunction,
                                        mode: AggregateMode,
                                        isDistinct: Boolean)
  extends Expression with Unevaluable {

  override def children: Seq[Expression] = aggregateFunction :: Nil

  override def dataType: DataType = aggregateFunction.dataType

  override def foldable: Boolean = false

  override def nullable: Boolean = aggregateFunction.nullable

  override def references: AttributeSet = {
    val childReferences = mode match {
      case Partial | Complete => aggregateFunction.references.toSeq
      case PartialMerge | Final => aggregateFunction.aggBufferAttributes
    }
    AttributeSet(childReferences)
  }

  override def prettyString: String = aggregateFunction.prettyString

  override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
}
Example 90
Source File: RuleExecutorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) } }
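The suite above drives RuleExecutor with a literal-rewriting rule; the same transform idiom can be tried directly on an expression tree. A minimal sketch of a hypothetical one-off constant fold (not part of the suite):

import org.apache.spark.sql.catalyst.expressions.{Add, IntegerLiteral, Literal}

// Fold Add(Literal, Literal) into a single Literal, in the same spirit as
// DecrementLiterals: transform applies the partial function over the tree.
val folded = Add(Literal(1), Literal(2)) transform {
  case Add(IntegerLiteral(a), IntegerLiteral(b)) => Literal(a + b)
}
assert(folded == Literal(3))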
Example 91
Source File: PartitioningSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} class PartitioningSuite extends SparkFunSuite { test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { val expressions = Seq(Literal(2), Literal(3)) // Consider two HashPartitionings that have the same _set_ of hash expressions but which are // created with different orderings of those expressions: val partitioningA = HashPartitioning(expressions, 100) val partitioningB = HashPartitioning(expressions.reverse, 100) // These partitionings are not considered equal: assert(partitioningA != partitioningB) // However, they both satisfy the same clustered distribution: val distribution = ClusteredDistribution(expressions) assert(partitioningA.satisfies(distribution)) assert(partitioningB.satisfies(distribution)) // These partitionings compute different hashcodes for the same input row: def computeHashCode(partitioning: HashPartitioning): Int = { val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) hashExprProj.apply(InternalRow.empty).hashCode() } assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) // Thus, these partitionings are incompatible: assert(!partitioningA.compatibleWith(partitioningB)) assert(!partitioningB.compatibleWith(partitioningA)) assert(!partitioningA.guarantees(partitioningB)) assert(!partitioningB.guarantees(partitioningA)) // Just to be sure that we haven't cheated by having these methods always return false, // check that identical partitionings are still compatible with and guarantee each other: assert(partitioningA === partitioningA) assert(partitioningA.guarantees(partitioningA)) assert(partitioningA.compatibleWith(partitioningA)) } }
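A small aside on why the suite compares partitionings built from the same literals: Literal is a case class over (value, dataType), so equal values of the same type compare equal and hash identically, and only the ordering of the hash expressions distinguishes the two partitionings. A trivial sketch:

import org.apache.spark.sql.catalyst.expressions.Literal

// Value-based equality and hashing for simple literal values.
assert(Literal(2) == Literal(2))
assert(Literal(2).hashCode == Literal(2).hashCode)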
Example 92
Source File: AggregateOptimizeSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Nil } test("replace distinct with aggregate") { val input = LocalRelation('a.int, 'b.int) val query = Distinct(input) val optimized = Optimize.execute(query.analyze) val correctAnswer = Aggregate(input.output, input.output, input) comparePlans(optimized, correctAnswer) } test("remove literals in grouping expression") { val input = LocalRelation('a.int, 'b.int) val query = input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(query) val correctAnswer = input.groupBy('a)(sum('b)) comparePlans(optimized, correctAnswer) } }
Example 93
Source File: ExpandSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.IntegerType class ExpandSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.localSeqToDataFrameHolder private def testExpand(f: SparkPlan => SparkPlan): Unit = { val input = (1 to 1000).map(Tuple1.apply) val projections = Seq.tabulate(2) { i => Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil } val attributes = projections.head.map(_.toAttribute) checkAnswer( input.toDF(), plan => Expand(projections, attributes, f(plan)), input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) ) } test("inheriting child row type") { val exprs = AttributeReference("a", IntegerType, false)() :: Nil val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") } test("expanding UnsafeRows") { testExpand(ConvertToUnsafe) } test("expanding SafeRows") { testExpand(identity) } }
Example 94
Source File: ExtraStrategiesSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package test.org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.{Row, Strategy, QueryTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.UTF8String case class FastOperator(output: Seq[Attribute]) extends SparkPlan { override protected def doExecute(): RDD[InternalRow] = { val str = Literal("so fast").value val row = new GenericInternalRow(Array[Any](str)) sparkContext.parallelize(Seq(row)) } override def children: Seq[SparkPlan] = Nil } object TestStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case Project(Seq(attr), _) if attr.name == "a" => FastOperator(attr.toAttribute :: Nil) :: Nil case _ => Nil } } class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("insert an extraStrategy") { try { sqlContext.experimental.extraStrategies = TestStrategy :: Nil val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( df.select("a"), Row("so fast")) checkAnswer( df.select("a", "b"), Row("so slow", 1)) } finally { sqlContext.experimental.extraStrategies = Nil } } }
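One detail worth calling out in the strategy above: Literal("so fast").value is already an internal-format value, which is why it can be placed straight into a GenericInternalRow. A small sketch of that behaviour:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.unsafe.types.UTF8String

// String literals are stored internally as UTF8String, not java.lang.String.
val v = Literal("so fast").value
assert(v.isInstanceOf[UTF8String])
assert(v.toString == "so fast")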
Example 95
Source File: PostgresIntegrationSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.jdbc import java.sql.Connection import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.{Literal, If} import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { override val imageName = "postgres:9.4.5" override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) override val jdbcPort = 5432 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" } override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE DATABASE foo").executeUpdate() conn.setCatalog("foo") conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " + "c10 integer[], c11 text[], c12 real[])").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate() } test("Type mapping for various types") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass) assert(types.length == 13) assert(classOf[String].isAssignableFrom(types(0))) assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) assert(classOf[java.lang.Double].isAssignableFrom(types(2))) assert(classOf[java.lang.Long].isAssignableFrom(types(3))) assert(classOf[java.lang.Boolean].isAssignableFrom(types(4))) assert(classOf[Array[Byte]].isAssignableFrom(types(5))) assert(classOf[Array[Byte]].isAssignableFrom(types(6))) assert(classOf[java.lang.Boolean].isAssignableFrom(types(7))) assert(classOf[String].isAssignableFrom(types(8))) assert(classOf[String].isAssignableFrom(types(9))) assert(classOf[Seq[Int]].isAssignableFrom(types(10))) assert(classOf[Seq[String]].isAssignableFrom(types(11))) assert(classOf[Seq[Double]].isAssignableFrom(types(12))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) assert(rows(0).getLong(3) == 123456789012345L) assert(rows(0).getBoolean(4) == false) // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) assert(rows(0).getBoolean(7) == true) assert(rows(0).getString(8) == "172.16.0.42") assert(rows(0).getString(9) == "192.168.0.0/16") assert(rows(0).getSeq(10) == Seq(1, 2)) assert(rows(0).getSeq(11) == Seq("a", null, "b")) assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) } test("Basic write test") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) // Test only that it doesn't crash. df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) // Test write null values. df.select(df.queryExecution.analyzed.output.map { a => Column(Literal.create(null, a.dataType)).as(a.name) }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) } }
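The write test above uses a common trick for writing typed nulls: wrapping Literal.create(null, dataType) in a Column, since a plain null literal carries no useful data type. A minimal sketch of that trick in isolation (IntegerType is just an illustrative choice):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.IntegerType

// A Column whose expression is a typed null literal; selecting it yields an
// INT column of nulls rather than an untyped NullType column.
val nullInt: Column = new Column(Literal.create(null, IntegerType))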
Example 96
Source File: SqlExtensionProviderSuite.scala From glow with Apache License 2.0 | 5 votes |
package io.projectglow.sql import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Literal, UnaryExpression} import org.apache.spark.sql.types.{DataType, IntegerType} import io.projectglow.GlowSuite class SqlExtensionProviderSuite extends GlowSuite { override def beforeAll(): Unit = { super.beforeAll() SqlExtensionProvider.registerFunctions( spark.sessionState.conf, spark.sessionState.functionRegistry, "test-functions.yml") } private lazy val sess = spark test("one arg function") { import sess.implicits._ assert(spark.range(1).selectExpr("one_arg_test(id)").as[Int].head() == 1) intercept[AnalysisException] { spark.range(1).selectExpr("one_arg_test()").collect() } intercept[AnalysisException] { spark.range(1).selectExpr("one_arg_test(id, id)").collect() } } test("two arg function") { import sess.implicits._ assert(spark.range(1).selectExpr("two_arg_test(id, id)").as[Int].head() == 1) intercept[AnalysisException] { spark.range(1).selectExpr("two_arg_test(id)").collect() } intercept[AnalysisException] { spark.range(1).selectExpr("two_arg_test(id, id, id)").collect() } } test("var args function") { import sess.implicits._ assert(spark.range(1).selectExpr("var_args_test(id, id)").as[Int].head() == 1) assert(spark.range(1).selectExpr("var_args_test(id, id, id, id)").as[Int].head() == 1) assert(spark.range(1).selectExpr("var_args_test(id)").as[Int].head() == 1) intercept[AnalysisException] { spark.range(1).selectExpr("var_args_test()").collect() } } test("can call optional arg function") { import sess.implicits._ assert(spark.range(1).selectExpr("optional_arg_test(id)").as[Int].head() == 1) assert(spark.range(1).selectExpr("optional_arg_test(id, id)").as[Int].head() == 1) intercept[AnalysisException] { spark.range(1).selectExpr("optional_arg_test()").collect() } intercept[AnalysisException] { spark.range(1).selectExpr("optional_arg_test(id, id, id)").collect() } } } trait TestExpr extends Expression with CodegenFallback { override def dataType: DataType = IntegerType override def nullable: Boolean = true override def eval(input: InternalRow): Any = 1 } case class OneArgExpr(child: Expression) extends UnaryExpression with TestExpr case class TwoArgExpr(left: Expression, right: Expression) extends BinaryExpression with TestExpr case class VarArgsExpr(arg: Expression, varArgs: Seq[Expression]) extends TestExpr { override def children: Seq[Expression] = arg +: varArgs } case class OptionalArgExpr(required: Expression, optional: Expression) extends TestExpr { def this(required: Expression) = this(required, Literal(1)) override def children: Seq[Expression] = Seq(required, optional) }
Example 97
Source File: BDJSpark.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.execution.SimbaPlan import org.apache.spark.sql.simba.partitioner.MapDPartition import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.{NumberUtil, ShapeUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable import scala.util.Random case class BDJSpark(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SimbaPlan { override def output: Seq[Attribute] = left.output ++ right.output final val num_partitions = simbaSessionState.simbaConf.joinPartitions final val r = NumberUtil.literalToDouble(l) override protected def doExecute(): RDD[InternalRow] = { val tot_rdd = left.execute().map((0, _)).union(right.execute().map((1, _))) val tot_dup_rdd = tot_rdd.flatMap {x => val rand_no = new Random().nextInt(num_partitions) var ans = mutable.ListBuffer[(Int, (Int, InternalRow))]() if (x._1 == 0) { val base = rand_no * num_partitions for (i <- 0 until num_partitions) ans += ((base + i, x)) } else { for (i <- 0 until num_partitions) ans += ((i * num_partitions + rand_no, x)) } ans } val tot_dup_partitioned = MapDPartition(tot_dup_rdd, num_partitions * num_partitions) tot_dup_partitioned.mapPartitions {iter => var left_data = mutable.ListBuffer[(Point, InternalRow)]() var right_data = mutable.ListBuffer[(Point, InternalRow)]() while (iter.hasNext) { val data = iter.next() if (data._2._1 == 0) { val tmp_point = ShapeUtils.getShape(left_key, left.output, data._2._2).asInstanceOf[Point] left_data += ((tmp_point, data._2._2)) } else { val tmp_point = ShapeUtils.getShape(right_key, right.output, data._2._2).asInstanceOf[Point] right_data += ((tmp_point, data._2._2)) } } val joined_ans = mutable.ListBuffer[InternalRow]() left_data.foreach {left => right_data.foreach {right => if (left._1.minDist(right._1) <= r) { joined_ans += new JoinedRow(left._2, right._2) } } } joined_ans.iterator } } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 98
Source File: BKJSpark.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.execution.SimbaPlan import org.apache.spark.sql.simba.partitioner.MapDPartition import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.ShapeUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.BoundedPriorityQueue import scala.collection.mutable import scala.util.Random case class BKJSpark(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SimbaPlan { override def output: Seq[Attribute] = left.output ++ right.output final val num_partitions = simbaSessionState.simbaConf.joinPartitions final val k = l.value.asInstanceOf[Number].intValue() private class DisOrdering extends Ordering[(InternalRow, Double)] { override def compare(x : (InternalRow, Double), y: (InternalRow, Double)): Int = -x._2.compare(y._2) } override protected def doExecute(): RDD[InternalRow] = { val tot_rdd = left.execute().map((0, _)).union(right.execute().map((1, _))) val tot_dup_rdd = tot_rdd.flatMap {x => val rand_no = new Random().nextInt(num_partitions) val ans = mutable.ListBuffer[(Int, (Int, InternalRow))]() if (x._1 == 0) { val base = rand_no * num_partitions for (i <- 0 until num_partitions) ans += ((base + i, x)) } else { for (i <- 0 until num_partitions) ans += ((i * num_partitions + rand_no, x)) } ans } val tot_dup_partitioned = MapDPartition(tot_dup_rdd, num_partitions * num_partitions) tot_dup_partitioned.mapPartitions {iter => var left_data = mutable.ListBuffer[(Point, InternalRow)]() var right_data = mutable.ListBuffer[(Point, InternalRow)]() while (iter.hasNext) { val data = iter.next() if (data._2._1 == 0) { val tmp_point = ShapeUtils.getShape(left_key, left.output, data._2._2).asInstanceOf[Point] left_data += ((tmp_point, data._2._2)) } else { val tmp_point = ShapeUtils.getShape(right_key, right.output, data._2._2).asInstanceOf[Point] right_data += ((tmp_point, data._2._2)) } } val joined_ans = mutable.ListBuffer[(InternalRow, Array[(InternalRow, Double)])]() left_data.foreach(left => { var pq = new BoundedPriorityQueue[(InternalRow, Double)](k)(new DisOrdering) right_data.foreach(right => pq += ((right._2, right._1.minDist(left._1)))) joined_ans += ((left._2, pq.toArray)) }) joined_ans.iterator }.reduceByKey((left, right) => (left ++ right).sortWith(_._2 < _._2).take(k), num_partitions) .flatMap { now => now._2.map(x => new JoinedRow(now._1, x._1)) } } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 99
Source File: BKJSparkR.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.execution.SimbaPlan import org.apache.spark.sql.simba.index.RTree import org.apache.spark.sql.simba.partitioner.MapDPartition import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.ShapeUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable import scala.util.Random case class BKJSparkR(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SimbaPlan { override def output: Seq[Attribute] = left.output ++ right.output final val num_partitions = simbaSessionState.simbaConf.joinPartitions final val max_entries_per_node = simbaSessionState.simbaConf.maxEntriesPerNode final val k = l.value.asInstanceOf[Number].intValue() private class DisOrdering extends Ordering[(InternalRow, Double)] { override def compare(x : (InternalRow, Double), y: (InternalRow, Double)): Int = -x._2.compare(y._2) } override protected def doExecute(): RDD[InternalRow] = { val tot_rdd = left.execute().map((0, _)).union(right.execute().map((1, _))) val tot_dup_rdd = tot_rdd.flatMap {x => val rand_no = new Random().nextInt(num_partitions) val ans = mutable.ListBuffer[(Int, (Int, InternalRow))]() if (x._1 == 0) { val base = rand_no * num_partitions for (i <- 0 until num_partitions) ans += ((base + i, x)) } else { for (i <- 0 until num_partitions) ans += ((i * num_partitions + rand_no, x)) } ans } val tot_dup_partitioned = MapDPartition(tot_dup_rdd, num_partitions * num_partitions) tot_dup_partitioned.mapPartitions {iter => var left_data = mutable.ListBuffer[(Point, InternalRow)]() var right_data = mutable.ListBuffer[(Point, InternalRow)]() while (iter.hasNext) { val data = iter.next() if (data._2._1 == 0) { val tmp_point = ShapeUtils.getShape(left_key, left.output, data._2._2).asInstanceOf[Point] left_data += ((tmp_point, data._2._2)) } else { val tmp_point = ShapeUtils.getShape(right_key, right.output, data._2._2).asInstanceOf[Point] right_data += ((tmp_point, data._2._2)) } } val joined_ans = mutable.ListBuffer[(InternalRow, Array[(InternalRow, Double)])]() if (right_data.nonEmpty) { val right_rtree = RTree(right_data.map(_._1).zipWithIndex.toArray, max_entries_per_node) left_data.foreach(left => joined_ans += ((left._2, right_rtree.kNN(left._1, k, keepSame = false) .map(x => (right_data(x._2)._2, x._1.minDist(left._1))))) ) } joined_ans.iterator }.reduceByKey((left, right) => (left ++ right).sortWith(_._2 < _._2).take(k), num_partitions) .flatMap { now => now._2.map(x => new JoinedRow(now._1, x._1)) } } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 100
Source File: BDJSparkR.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.execution.SimbaPlan import org.apache.spark.sql.simba.index.RTree import org.apache.spark.sql.simba.partitioner.MapDPartition import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.{NumberUtil, ShapeUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable import scala.util.Random case class BDJSparkR(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SimbaPlan { override def output: Seq[Attribute] = left.output ++ right.output final val num_partitions = simbaSessionState.simbaConf.joinPartitions final val r = NumberUtil.literalToDouble(l) final val max_entries_per_node = simbaSessionState.simbaConf.maxEntriesPerNode override protected def doExecute(): RDD[InternalRow] = { val tot_rdd = left.execute().map((0, _)).union(right.execute().map((1, _))) val tot_dup_rdd = tot_rdd.flatMap {x => val rand_no = new Random().nextInt(num_partitions) var ans = mutable.ListBuffer[(Int, (Int, InternalRow))]() if (x._1 == 0) { val base = rand_no * num_partitions for (i <- 0 until num_partitions) ans += ((base + i, x)) } else { for (i <- 0 until num_partitions) ans += ((i * num_partitions + rand_no, x)) } ans } val tot_dup_partitioned = MapDPartition(tot_dup_rdd, num_partitions * num_partitions) tot_dup_partitioned.mapPartitions {iter => var left_data = mutable.ListBuffer[(Point, InternalRow)]() var right_data = mutable.ListBuffer[(Point, InternalRow)]() while (iter.hasNext) { val data = iter.next() if (data._2._1 == 0) { val tmp_point = ShapeUtils.getShape(left_key, left.output, data._2._2).asInstanceOf[Point] left_data += ((tmp_point, data._2._2)) } else { val tmp_point = ShapeUtils.getShape(right_key, right.output, data._2._2).asInstanceOf[Point] right_data += ((tmp_point, data._2._2)) } } val joined_ans = mutable.ListBuffer[InternalRow]() if (right_data.nonEmpty) { val right_rtree = RTree(right_data.map(_._1).zipWithIndex.toArray, max_entries_per_node) left_data.foreach(left => right_rtree.circleRange(left._1, r) .foreach(x => joined_ans += new JoinedRow(left._2, right_data(x._2)._2))) } joined_ans.iterator } } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 101
Source File: DJSpark.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.execution.SimbaPlan import org.apache.spark.sql.simba.index.RTree import org.apache.spark.sql.simba.partitioner.{MapDPartition, STRPartition} import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.{NumberUtil, ShapeUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable case class DJSpark(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SimbaPlan { override def output: Seq[Attribute] = left.output ++ right.output final val num_partitions = simbaSessionState.simbaConf.joinPartitions final val sample_rate = simbaSessionState.simbaConf.sampleRate final val max_entries_per_node = simbaSessionState.simbaConf.maxEntriesPerNode final val transfer_threshold = simbaSessionState.simbaConf.transferThreshold final val r = NumberUtil.literalToDouble(l) override protected def doExecute(): RDD[InternalRow] = { val left_rdd = left.execute().map(row => (ShapeUtils.getShape(left_key, left.output, row).asInstanceOf[Point], row) ) val right_rdd = right.execute().map(row => (ShapeUtils.getShape(right_key, right.output, row).asInstanceOf[Point], row) ) val dimension = right_rdd.first()._1.coord.length val (left_partitioned, left_mbr_bound) = STRPartition(left_rdd, dimension, num_partitions, sample_rate, transfer_threshold, max_entries_per_node) val (right_partitioned, right_mbr_bound) = STRPartition(right_rdd, dimension, num_partitions, sample_rate, transfer_threshold, max_entries_per_node) val right_rt = RTree(right_mbr_bound.zip(Array.fill[Int](right_mbr_bound.length)(0)) .map(x => (x._1._1, x._1._2, x._2)), max_entries_per_node) val left_dup = new Array[Array[Int]](left_mbr_bound.length) val right_dup = new Array[Array[Int]](right_mbr_bound.length) var tot = 0 left_mbr_bound.foreach { now => val res = right_rt.circleRange(now._1, r) val tmp_arr = mutable.ArrayBuffer[Int]() res.foreach {x => if (right_dup(x._2) == null) right_dup(x._2) = Array(tot) else right_dup(x._2) = right_dup(x._2) :+ tot tmp_arr += tot tot += 1 } left_dup(now._2) = tmp_arr.toArray } val bc_left_dup = sparkContext.broadcast(left_dup) val bc_right_dup = sparkContext.broadcast(right_dup) val left_dup_rdd = left_partitioned.mapPartitionsWithIndex { (id, iter) => iter.flatMap {now => val tmp_list = bc_left_dup.value(id) if (tmp_list != null) tmp_list.map(x => (x, now)) else Array[(Int, (Point, InternalRow))]() } } val right_dup_rdd = right_partitioned.mapPartitionsWithIndex { (id, iter) => iter.flatMap {now => val tmp_list = bc_right_dup.value(id) if (tmp_list != null) tmp_list.map(x => (x, now)) else Array[(Int, (Point, InternalRow))]() } } val left_dup_partitioned = MapDPartition(left_dup_rdd, tot).map(_._2) val right_dup_partitioned = MapDPartition(right_dup_rdd, tot).map(_._2) left_dup_partitioned.zipPartitions(right_dup_partitioned) {(leftIter, rightIter) => val ans = mutable.ListBuffer[InternalRow]() val right_data = rightIter.toArray if (right_data.nonEmpty) { val right_index = RTree(right_data.map(_._1).zipWithIndex, max_entries_per_node) leftIter.foreach {now => ans ++= right_index.circleRange(now._1, r) .map(x => new JoinedRow(now._2, right_data(x._2)._2)) } } ans.iterator } } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 102
Source File: RDJSpark.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.execution.SimbaPlan import org.apache.spark.sql.simba.index.RTree import org.apache.spark.sql.simba.partitioner.{MapDPartition, STRPartition} import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.{NumberUtil, ShapeUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable case class RDJSpark(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SimbaPlan { override def output: Seq[Attribute] = left.output ++ right.output final val num_partitions = simbaSessionState.simbaConf.joinPartitions final val sample_rate = simbaSessionState.simbaConf.sampleRate final val max_entries_per_node = simbaSessionState.simbaConf.maxEntriesPerNode final val transfer_threshold = simbaSessionState.simbaConf.transferThreshold final val r = NumberUtil.literalToDouble(l) override protected def doExecute(): RDD[InternalRow] = { val left_rdd = left.execute().map(row => (ShapeUtils.getShape(left_key, left.output, row).asInstanceOf[Point], row) ) val right_rdd = right.execute().map(row => (ShapeUtils.getShape(right_key, right.output, row).asInstanceOf[Point], row) ) val dimension = right_rdd.first()._1.coord.length val (left_partitioned, left_mbr_bound) = STRPartition(left_rdd, dimension, num_partitions, sample_rate, transfer_threshold, max_entries_per_node) val left_part_size = left_partitioned.mapPartitions { iter => Array(iter.length).iterator }.collect() val left_rt = RTree(left_mbr_bound.zip(left_part_size).map(x => (x._1._1, x._1._2, x._2)), max_entries_per_node) val bc_rt = sparkContext.broadcast(left_rt) val right_dup = right_rdd.flatMap {x => bc_rt.value.circleRange(x._1, r).map(now => (now._2, x)) } val right_dup_partitioned = MapDPartition(right_dup, left_mbr_bound.length) left_partitioned.zipPartitions(right_dup_partitioned) {(leftIter, rightIter) => val ans = mutable.ListBuffer[InternalRow]() val right_data = rightIter.map(_._2).toArray if (right_data.length > 0) { val right_index = RTree(right_data.map(_._1).zipWithIndex, max_entries_per_node) leftIter.foreach {now => ans ++= right_index.circleRange(now._1, r) .map(x => new JoinedRow(now._2, right_data(x._2)._2)) } } ans.iterator } } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 103
Source File: CKJSpark.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.execution.SimbaPlan import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.ShapeUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan case class CKJSpark(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SimbaPlan { override def outputPartitioning: Partitioning = left.outputPartitioning override def output: Seq[Attribute] = left.output ++ right.output final val k = l.value.asInstanceOf[Number].intValue() override protected def doExecute(): RDD[InternalRow] = { val left_rdd = left.execute() val right_rdd = right.execute() left_rdd.map(row => (ShapeUtils.getShape(left_key, left.output, row).asInstanceOf[Point], row) ).cartesian(right_rdd).map { case (l: (Point, InternalRow), r: InternalRow) => val tmp_point = ShapeUtils.getShape(right_key, right.output, r).asInstanceOf[Point] l._2 -> List((tmp_point.minDist(l._1), r)) }.reduceByKey { case (l_list: Seq[(Double, InternalRow)], r_list: Seq[(Double, InternalRow)]) => (l_list ++ r_list).sortWith(_._1 < _._1).take(k) }.flatMapValues(list => list).mapPartitions { iter => val joinedRow = new JoinedRow iter.map(r => joinedRow(r._1, r._2._2)) } } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 104
Source File: CDJSpark.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution.join import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.{NumberUtil, ShapeUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan case class CDJSpark(left_key: Expression, right_key: Expression, l: Literal, left: SparkPlan, right: SparkPlan) extends SparkPlan { override def outputPartitioning: Partitioning = left.outputPartitioning override def output: Seq[Attribute] = left.output ++ right.output final val r = NumberUtil.literalToDouble(l) override protected def doExecute(): RDD[InternalRow] = left.execute().cartesian(right.execute()).mapPartitions { iter => val joinedRow = new JoinedRow iter.filter { row => val point1 = ShapeUtils.getShape(left_key, left.output, row._1).asInstanceOf[Point] val point2 = ShapeUtils.getShape(right_key, right.output, row._2).asInstanceOf[Point] point1.minDist(point2) <= r }.map(row => joinedRow(row._1, row._2)) } override def children: Seq[SparkPlan] = Seq(left, right) }
Example 105
Source File: FilterExec.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.execution import org.apache.spark.sql.simba.expression._ import org.apache.spark.sql.simba.spatial.Point import org.apache.spark.sql.simba.util.ShapeUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal, PredicateHelper} import org.apache.spark.sql.catalyst.expressions.{SortOrder, And => SQLAnd, Not => SQLNot, Or => SQLOr} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan case class FilterExec(condition: Expression, child: SparkPlan) extends SimbaPlan with PredicateHelper { override def output: Seq[Attribute] = child.output private class DistanceOrdering(point: Expression, target: Point) extends Ordering[InternalRow] { override def compare(x: InternalRow, y: InternalRow): Int = { val shape_x = ShapeUtils.getShape(point, child.output, x) val shape_y = ShapeUtils.getShape(point, child.output, y) val dis_x = target.minDist(shape_x) val dis_y = target.minDist(shape_y) dis_x.compare(dis_y) } } // TODO change target partition from 1 to some good value // Note that target here must be an point literal in WHERE clause, // hence we can consider it as Point safely def knn(rdd: RDD[InternalRow], point: Expression, target: Point, k: Int): RDD[InternalRow] = sparkContext.parallelize(rdd.map(_.copy()).takeOrdered(k)(new DistanceOrdering(point, target)), 1) def applyCondition(rdd: RDD[InternalRow], condition: Expression): RDD[InternalRow] = { condition match { case InKNN(point, target, k) => val _target = target.asInstanceOf[Literal].value.asInstanceOf[Point] knn(rdd, point, _target, k.value.asInstanceOf[Number].intValue()) case now@And(left, right) => if (!now.hasKNN) rdd.mapPartitions{ iter => iter.filter(newPredicate(condition, child.output).eval(_))} else applyCondition(rdd, left).map(_.copy()).intersection(applyCondition(rdd, right).map(_.copy())) case now@Or(left, right) => if (!now.hasKNN) rdd.mapPartitions{ iter => iter.filter(newPredicate(condition, child.output).eval(_))} else applyCondition(rdd, left).map(_.copy()).union(applyCondition(rdd, right).map(_.copy())).distinct() case now@Not(c) => if (!now.hasKNN) rdd.mapPartitions{ iter => iter.filter(newPredicate(condition, child.output).eval(_))} else rdd.map(_.copy()).subtract(applyCondition(rdd, c).map(_.copy())) case _ => rdd.mapPartitions(iter => iter.filter(newPredicate(condition, child.output).eval(_))) } } protected def doExecute(): RDD[InternalRow] = { val root_rdd = child.execute() condition transformUp { case SQLAnd(left, right) => And(left, right) case SQLOr(left, right)=> Or(left, right) case SQLNot(c) => Not(c) } applyCondition(root_rdd, condition) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def children: Seq[SparkPlan] = child :: Nil override def outputPartitioning: Partitioning = child.outputPartitioning }
Example 106
Source File: InRange.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.expression import org.apache.spark.sql.simba.{ShapeSerializer, ShapeType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Predicate} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.simba.spatial.{MBR, Point, Shape} import org.apache.spark.sql.simba.util.ShapeUtils import org.apache.spark.sql.catalyst.util.GenericArrayData case class InRange(shape: Expression, range_low: Expression, range_high: Expression) extends Predicate with CodegenFallback{ override def nullable: Boolean = false override def eval(input: InternalRow): Any = { val eval_shape = ShapeUtils.getShape(shape, input) val eval_low = range_low.asInstanceOf[Literal].value.asInstanceOf[Point] val eval_high = range_high.asInstanceOf[Literal].value.asInstanceOf[Point] require(eval_shape.dimensions == eval_low.dimensions && eval_shape.dimensions == eval_high.dimensions) val mbr = MBR(eval_low, eval_high) mbr.intersects(eval_shape) } override def toString: String = s" **($shape) IN Rectangle ($range_low) - ($range_high)** " override def children: Seq[Expression] = Seq(shape, range_low, range_high) }
Example 107
Source File: NumberUtil.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.util import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.types._ object NumberUtil { def literalToDouble(x: Literal): Double = { x.value match { case double_value: Number => double_value.doubleValue() case decimal_value: Decimal => decimal_value.toDouble } } def isIntegral(x: DataType): Boolean = { x match { case IntegerType => true case LongType => true case ShortType => true case ByteType => true case _ => false } } }
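As a quick usage note for the helper above (values chosen only for illustration), both integral and fractional literals fall into the Number branch:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.simba.util.NumberUtil

// Int and Double literal values are boxed java.lang.Numbers at runtime,
// so both take the first branch of literalToDouble.
val a = NumberUtil.literalToDouble(Literal(3))    // 3.0
val b = NumberUtil.literalToDouble(Literal(2.5))  // 2.5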
Example 108
Source File: ParamBinder.scala From spark-sql-server with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.server.service import java.sql.SQLException import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder object ParamBinder { def bind(logicalPlan: LogicalPlan, params: Map[Int, Literal]): LogicalPlan = { val boundPlan = logicalPlan.transformAllExpressions { case ParameterPlaceHolder(id) if params.contains(id) => params(id) } val unresolvedParams = boundPlan.flatMap { plan => plan.expressions.flatMap { _.flatMap { case ParameterPlaceHolder(id) => Some(id) case _ => None }} } if (unresolvedParams.nonEmpty) { throw new SQLException("Unresolved parameters found: " + unresolvedParams.map(n => s"$$$n").mkString(", ")) } boundPlan } }
Example 109
Source File: OperationManager.scala From spark-sql-server with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.server.service import javax.annotation.concurrent.ThreadSafe import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.server.SQLServerConf._ import org.apache.spark.sql.server.SQLServerEnv import org.apache.spark.sql.types.StructType sealed trait OperationState case object INITIALIZED extends OperationState case object RUNNING extends OperationState case object FINISHED extends OperationState case object CANCELED extends OperationState case object CLOSED extends OperationState case object ERROR extends OperationState case object UNKNOWN extends OperationState case object PENDING extends OperationState sealed trait OperationType { override def toString: String = getClass.getSimpleName.stripSuffix("$") } object BEGIN extends OperationType object FETCH extends OperationType object SELECT extends OperationType @ThreadSafe trait Operation { private val timeout = SQLServerEnv.sqlConf.sqlServerIdleOperationTimeout private var lastAccessTime: Long = System.currentTimeMillis() @volatile protected var state: OperationState = INITIALIZED def statementId(): String def outputSchema(): StructType def prepare(params: Map[Int, Literal]): Unit def run(): Iterator[InternalRow] def cancel(): Unit def close(): Unit protected def setState(newState: OperationState): Unit = { lastAccessTime = System.currentTimeMillis() state = newState } def isTimeOut(current: Long): Boolean = { if (timeout == 0) { true } else if (timeout > 0) { Seq(FINISHED, CANCELED, CLOSED, ERROR).contains(state) && lastAccessTime + timeout <= current } else { lastAccessTime - timeout <= current } } } object NOP extends Operation { override val statementId: String = "nop" override val outputSchema: StructType = new StructType() override def prepare(params: Map[Int, Literal]): Unit = {} override def run(): Iterator[InternalRow] = Iterator.empty override def cancel(): Unit = {} override def close(): Unit = {} } trait OperationExecutor { // Creates a new instance for service-specific operations def newOperation( sessionState: SessionState, statementId: String, query: (String, LogicalPlan)): Operation }
Example 110
Source File: ParameterBinderSuite.scala From spark-sql-server with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.server.service.postgresql.protocol.v3 import java.sql.SQLException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder import org.apache.spark.sql.server.service.ParamBinder import org.apache.spark.sql.types._ class ParameterBinderSuite extends PlanTest { test("bind parameters") { val c0 = 'a.int val c1 = 'b.int val r1 = LocalRelation(c0, c1) val param1 = Literal(18, IntegerType) val lp1 = Filter(EqualTo(c0, ParameterPlaceHolder(1)), r1) val expected1 = Filter(EqualTo(c0, param1), r1) comparePlans(expected1, ParamBinder.bind(lp1, Map(1 -> param1))) val param2 = Literal(42, IntegerType) val lp2 = Filter(EqualTo(c0, ParameterPlaceHolder(300)), r1) val expected2 = Filter(EqualTo(c0, param2), r1) comparePlans(expected2, ParamBinder.bind(lp2, Map(300 -> param2))) val param3 = Literal(-1, IntegerType) val param4 = Literal(48, IntegerType) val lp3 = Filter( And( EqualTo(c0, ParameterPlaceHolder(1)), EqualTo(c1, ParameterPlaceHolder(2)) ), r1) val expected3 = Filter( And( EqualTo(c0, param3), EqualTo(c1, param4) ), r1) comparePlans(expected3, ParamBinder.bind(lp3, Map(1 -> param3, 2 -> param4))) val errMsg1 = intercept[SQLException] { ParamBinder.bind(lp1, Map.empty) }.getMessage assert(errMsg1 == "Unresolved parameters found: $1") val errMsg2 = intercept[SQLException] { ParamBinder.bind(lp2, Map.empty) }.getMessage assert(errMsg2 == "Unresolved parameters found: $300") val errMsg3 = intercept[SQLException] { ParamBinder.bind(lp3, Map.empty) }.getMessage assert(errMsg3 == "Unresolved parameters found: $1, $2") } }