org.apache.spark.sql.catalyst.expressions.And Scala Examples
The following examples show how to use org.apache.spark.sql.catalyst.expressions.And.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
Example 1
Source File: DeltaPushFilter.scala From connectors with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.delta

import scala.collection.immutable.HashSet
import scala.collection.JavaConverters._

import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, SerializationUtilities}
import org.apache.hadoop.hive.ql.lib._
import org.apache.hadoop.hive.ql.parse.SemanticException
import org.apache.hadoop.hive.ql.plan.{ExprNodeColumnDesc, ExprNodeConstantDesc, ExprNodeGenericFuncDesc}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, InSet, LessThan, LessThanOrEqual, Like, Literal, Not}

/**
 * Translates a serialized Hive filter expression into Catalyst expressions.
 */
object DeltaPushFilter extends Logging {

  /** Fully-qualified class names of the Hive UDFs this object lists as push-down candidates. */
  lazy val supportedPushDownUDFs = Array(
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualNS",
    "org.apache.hadoop.hive.ql.udf.UDFLike",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn"
  )

  /**
   * Deserializes a Hive filter expression and converts it to a Catalyst expression by
   * walking the Hive expression tree.
   *
   * AND nodes are folded into [[And]]; every other function node is expected to combine one
   * column with one constant. Column/constant nodes themselves produce no output.
   *
   * @param hiveFilterExprSeriablized the serialized Hive expression, or null
   * @return a single-element Seq holding the converted expression; empty when the input is
   *         null or when the root node produced no expression
   * @throws RuntimeException if the walk fails, including for unsupported functions
   */
  def partitionFilterConverter(hiveFilterExprSeriablized: String): Seq[Expression] = {
    if (hiveFilterExprSeriablized != null) {
      val filterExpr = SerializationUtilities.deserializeExpression(hiveFilterExprSeriablized)
      val opRules = new java.util.LinkedHashMap[Rule, NodeProcessor]()
      val nodeProcessor = new NodeProcessor() {
        @throws[SemanticException]
        def process(nd: Node, stack: java.util.Stack[Node],
            procCtx: NodeProcessorCtx, nodeOutputs: Object*): Object = {
          nd match {
            case e: ExprNodeGenericFuncDesc if FunctionRegistry.isOpAnd(e) =>
              nodeOutputs.map(_.asInstanceOf[Expression]).reduce(And)

            case e: ExprNodeGenericFuncDesc =>
              val firstChild = nd.getChildren.get(0)
              val secondChild = nd.getChildren.get(1)
              // Remember the original operand order: `5 < col` must not become `col < 5`.
              val columnFirst = firstChild.isInstanceOf[ExprNodeColumnDesc]
              val (columnDesc, constantDesc) =
                if (columnFirst) (firstChild, secondChild) else (secondChild, firstChild)

              val columnAttr = UnresolvedAttribute(
                columnDesc.asInstanceOf[ExprNodeColumnDesc].getColumn)
              val constantVal = Literal(constantDesc.asInstanceOf[ExprNodeConstantDesc].getValue)

              e.getGenericUDF match {
                // Equality-style operators are symmetric, so operand order is irrelevant.
                case _: GenericUDFOPNotEqualNS => Not(EqualNullSafe(columnAttr, constantVal))
                case _: GenericUDFOPNotEqual => Not(EqualTo(columnAttr, constantVal))
                case _: GenericUDFOPEqualNS => EqualNullSafe(columnAttr, constantVal)
                case _: GenericUDFOPEqual => EqualTo(columnAttr, constantVal)

                // Ordering operators are mirrored when the constant was the left operand,
                // keeping the column on the left of the resulting expression.
                case _: GenericUDFOPGreaterThan =>
                  if (columnFirst) GreaterThan(columnAttr, constantVal)
                  else LessThan(columnAttr, constantVal)
                case _: GenericUDFOPEqualOrGreaterThan =>
                  if (columnFirst) GreaterThanOrEqual(columnAttr, constantVal)
                  else LessThanOrEqual(columnAttr, constantVal)
                case _: GenericUDFOPLessThan =>
                  if (columnFirst) LessThan(columnAttr, constantVal)
                  else GreaterThan(columnAttr, constantVal)
                case _: GenericUDFOPEqualOrLessThan =>
                  if (columnFirst) LessThanOrEqual(columnAttr, constantVal)
                  else GreaterThanOrEqual(columnAttr, constantVal)

                // LIKE is not symmetric: keep value and pattern in their original positions.
                case f: GenericUDFBridge if f.getUdfName.equals("like") =>
                  if (columnFirst) Like(columnAttr, constantVal)
                  else Like(constantVal, columnAttr)

                case _: GenericUDFIn =>
                  // NOTE(review): every constant child is collected, so a constant-first IN
                  // (`5 IN (col, ...)`) is still treated as `col IN (5, ...)` — confirm Hive
                  // never produces that shape.
                  // NOTE(review): the set holds Literal wrappers rather than raw values —
                  // verify this is what the consumer of InSet expects.
                  val inConstantVals = nd.getChildren.asScala
                    .filter(_.isInstanceOf[ExprNodeConstantDesc])
                    .map(_.asInstanceOf[ExprNodeConstantDesc].getValue)
                    .map(Literal(_)).toSet
                  InSet(columnAttr, HashSet() ++ inConstantVals)

                case _ =>
                  throw new RuntimeException(s"Unsupported func(${nd.getName}) " +
                    s"which can not be pushed down to delta")
              }

            // Column and constant leaves carry no expression of their own.
            case _ => null
          }
        }
      }

      val disp = new DefaultRuleDispatcher(nodeProcessor, opRules, null)
      val ogw = new DefaultGraphWalker(disp)
      val topNodes = new java.util.ArrayList[Node]()
      topNodes.add(filterExpr)
      val nodeOutput = new java.util.HashMap[Node, Object]()
      try {
        ogw.startWalking(topNodes, nodeOutput)
      } catch {
        case ex: Exception =>
          throw new RuntimeException(ex)
      }

      // The root can map to null (e.g. the filter is a bare column); return no filter in that
      // case instead of failing with a NullPointerException on `.toJSON`.
      Option(nodeOutput.get(filterExpr)) match {
        case Some(converted: Expression) =>
          logInfo(s"converted partition filter expr:" + s"${converted.toJSON}")
          Seq(converted)
        case _ =>
          Seq.empty[Expression]
      }
    } else Seq.empty[org.apache.spark.sql.catalyst.expressions.Expression]
  }
}
Example 2
Source File: DataSourceV2Strategy.scala From XSQL with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.{sources, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader

/** Planner strategy that maps DataSourceV2 logical nodes to their physical operators. */
object DataSourceV2Strategy extends Strategy {

  /**
   * Computes the scan output after asking the reader to drop columns that `exprs` never
   * reference. Readers without column-pruning support keep the relation's full output.
   */
  // TODO: nested column pruning.
  private def pruneColumns(
      reader: DataSourceReader,
      relation: DataSourceV2Relation,
      exprs: Seq[Expression]): Seq[AttributeReference] = reader match {
    case pruner: SupportsPushDownRequiredColumns =>
      val referenced = AttributeSet(exprs.flatMap(_.references))
      val kept = relation.output.filter(referenced.contains)
      if (kept == relation.output) {
        relation.output
      } else {
        pruner.pruneColumns(kept.toStructType)
        val attrByName = relation.output.map(attr => attr.name -> attr).toMap
        // We have to keep the attribute id during transformation.
        pruner.readSchema().toAttributes.map { attr =>
          attr.withExprId(attrByName(attr.name).exprId)
        }
      }

    case _ =>
      relation.output
  }

  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
      val reader = relation.newReader()
      // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
      // `postScanFilters` need to be evaluated after the scan.
      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
      val (pushedFilters, postScanFilters) = pushFilters(reader, filters)
      val output = pruneColumns(reader, relation, project ++ postScanFilters)
      logInfo(
        s"""
           |Pushing operators to ${relation.source.getClass}
           |Pushed Filters: ${pushedFilters.mkString(", ")}
           |Post-Scan Filters: ${postScanFilters.mkString(",")}
           |Output: ${output.mkString(", ")}
         """.stripMargin)

      val scan = DataSourceV2ScanExec(
        output, relation.source, relation.options, pushedFilters, reader)

      // Wrap the scan in a filter only when something is left to evaluate after the scan.
      val filtered: SparkPlan = postScanFilters.reduceLeftOption(And) match {
        case Some(condition) => FilterExec(condition, scan)
        case None => scan
      }

      // always add the projection, which will produce unsafe rows required by some operators
      Seq(ProjectExec(project, filtered))

    case r: StreamingDataSourceV2Relation =>
      // ensure there is a projection, which will produce unsafe rows required by some operators
      val streamingScan =
        DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)
      Seq(ProjectExec(r.output, streamingScan))

    case WriteToDataSourceV2(writer, query) =>
      Seq(WriteToDataSourceV2Exec(writer, planLater(query)))

    case AppendData(r: DataSourceV2Relation, query, _) =>
      Seq(WriteToDataSourceV2Exec(r.newWriter(), planLater(query)))

    case WriteToContinuousDataSource(writer, query) =>
      Seq(WriteToContinuousDataSourceExec(writer, planLater(query)))

    case Repartition(1, false, child) =>
      // Only a subtree containing a continuous reader is planned as a continuous coalesce.
      val hasContinuousReader = child.collectFirst {
        case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
      }.nonEmpty
      if (hasContinuousReader) Seq(ContinuousCoalesceExec(1, planLater(child))) else Nil

    case _ => Nil
  }
}
Example 3
Source File: SemiJoinSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

// Semi-join test suite
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  // Left input: columns (a: Int, b: Double), including duplicate and null-bearing rows.
  private lazy val left = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  // Right input: columns (c: Int, d: Double), including duplicate and null-bearing rows.
  private lazy val right = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Join condition: a = c AND b < d.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Extracts the equi-join keys and the remaining condition from an inner join of the inputs.
    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    // Runs the given physical plan with one shuffle partition and compares sorted answers.
    def verifyWith(buildPlan: (SparkPlan, SparkPlan) => SparkPlan): Unit =
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(
          leftRows,
          rightRows,
          buildPlan,
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts() match {
        case Some((_, leftKeys, rightKeys, boundCondition, _, _)) =>
          verifyWith { (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
            EnsureRequirements(leftPlan.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, leftPlan, rightPlan, boundCondition))
          }
        case None => // no equi-join keys were extracted: nothing to check
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts() match {
        case Some((_, leftKeys, rightKeys, boundCondition, _, _)) =>
          verifyWith { (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, leftPlan, rightPlan, boundCondition)
          }
        case None => // no equi-join keys were extracted: nothing to check
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      verifyWith { (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
        LeftSemiJoinBNL(leftPlan, rightPlan, Some(condition))
      }
    }
  }

  // Test the left semi join
  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
}
Example 4
Source File: BatchEvalPythonExecSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

/**
 * Checks the shape of executed plans for filters that mix a dummy Python UDF with other
 * predicates.
 */
class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  // Register the dummy UDF for the lifetime of the suite and drop it afterwards.
  override def beforeAll(): Unit = {
    super.beforeAll()
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
  }

  override def afterAll(): Unit = {
    spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF"))
    super.afterAll()
  }

  // One-row (a: String, b: Int) input shared by the filter tests.
  private def helloFrame = Seq(("Hello", 4)).toDF("a", "b")

  test("Python UDF: push down deterministic FilterExec predicates") {
    val filtered = helloFrame
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val matches = filtered.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: AttributeReference),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(matches.size == 2)
  }

  test("Nested Python UDF: push down deterministic FilterExec predicates") {
    val filtered = helloFrame
      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
    val matches = filtered.queryExecution.executedPlan.collect {
      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(matches.size == 2)
  }

  test("Python UDF: no push down on non-deterministic") {
    val filtered = helloFrame
      .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3")
    val matches = filtered.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(matches.size == 2)
  }

  test("Python UDF: push down on deterministic predicates after the first non-deterministic") {
    val filtered = helloFrame
      .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4")
    val matches = filtered.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(matches.size == 2)
  }

  test("Python UDF refers to the attributes from more than one child") {
    val leftFrame = helloFrame
    val rightFrame = Seq(("Hello", 4)).toDF("c", "d")
    val joined = leftFrame.crossJoin(rightFrame)
      .where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
    val matches = joined.queryExecution.executedPlan.collect {
      case b: BatchEvalPythonExec => b
    }
    assert(matches.size == 1)
  }
}

// This Python UDF is dummy and just for testing. Unable to execute.
class DummyUDF extends PythonFunction(
  command = Array[Byte](),
  envVars = Map("" -> "").asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  pythonExec = "",
  pythonVer = "",
  broadcastVars = null,
  accumulator = null)

// Boolean-typed, deterministic batched UDF wrapping DummyUDF; registered by the suite above.
class MyDummyPythonUDF extends UserDefinedPythonFunction(
  name = "dummyUDF",
  func = new DummyUDF,
  dataType = BooleanType,
  pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
  udfDeterministic = true)
Example 5
Source File: GroupAnd.scala From mimir with Apache License 2.0 | 5 votes |
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, BooleanType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, And }

/**
 * Aggregate that ANDs together the boolean values of `child` across a group.
 *
 * The buffer is a single boolean attribute initialised to `true`; each input row is folded
 * in with [[And]], and partial buffers are merged with [[And]] as well.
 *
 * @param child the expression whose values are combined; must be of boolean type
 */
case class GroupAnd(child: org.apache.spark.sql.catalyst.expressions.Expression)
  extends DeclarativeAggregate {

  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil

  // NOTE(review): `And(group_and, child)` may evaluate to null for a null `child` under SQL
  // three-valued logic, which would contradict `nullable = false` — confirm inputs are
  // never null or revisit this flag.
  override def nullable: Boolean = false

  // Return data type.
  override def dataType: DataType = BooleanType

  /**
   * Accepts only boolean input. The buffer and the update expression are boolean, so
   * checking merely for an orderable type would let e.g. integer columns through.
   */
  override def checkInputDataTypes(): TypeCheckResult =
    if (child.dataType == BooleanType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"function group_and requires boolean type, not ${child.dataType.simpleString}")
    }

  // Single-slot aggregation buffer holding the running conjunction.
  private lazy val group_and = AttributeReference("group_and", BooleanType)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_and :: Nil

  // `true` is the identity element of AND.
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(true, BooleanType)
  )

  override lazy val updateExpressions: Seq[
    org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    And(group_and, child)
  )

  // `.left` / `.right` presumably come from DeclarativeAggregate's implicit accessors for
  // the two buffers being merged — verify against the Spark version in use.
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      And(group_and.left, group_and.right)
    )
  }

  override lazy val evaluateExpression: AttributeReference = group_and
}
Example 6
Source File: SemiJoinSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

/** Runs the same left-semi-join scenario against each physical semi-join operator. */
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  // Left input: columns (a: Int, b: Double), including duplicate and null-bearing rows.
  private lazy val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  // Right input: columns (c: Int, d: Double), including duplicate and null-bearing rows.
  private lazy val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Join condition: a = c AND b < d.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Extracts the equi-join keys and the remaining condition from an inner join of the inputs.
    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    // Runs the given physical plan with one shuffle partition and compares sorted answers.
    def verifyWith(buildPlan: (SparkPlan, SparkPlan) => SparkPlan): Unit =
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(
          leftRows,
          rightRows,
          buildPlan,
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts() match {
        case Some((_, leftKeys, rightKeys, boundCondition, _, _)) =>
          verifyWith { (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
            EnsureRequirements(leftPlan.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, leftPlan, rightPlan, boundCondition))
          }
        case None => // no equi-join keys were extracted: nothing to check
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts() match {
        case Some((_, leftKeys, rightKeys, boundCondition, _, _)) =>
          verifyWith { (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, leftPlan, rightPlan, boundCondition)
          }
        case None => // no equi-join keys were extracted: nothing to check
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      verifyWith { (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
        LeftSemiJoinBNL(leftPlan, rightPlan, Some(condition))
      }
    }
  }

  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
}
Example 7
Source File: SimbaOptimizer.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkOptimizer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.simba.plans.SpatialJoin

/** SparkOptimizer extended with one extra batch that pushes predicates through spatial joins. */
class SimbaOptimizer(
    catalog: SessionCatalog,
    conf: SQLConf,
    experimentalMethods: ExperimentalMethods)
  extends SparkOptimizer(catalog, conf, experimentalMethods) {

  override def batches: Seq[Batch] = {
    val spatialPushDown =
      Batch("SpatialJoinPushDown", FixedPoint(100), PushPredicateThroughSpatialJoin)
    super.batches :+ spatialPushDown
  }
}

/**
 * Moves predicates that reference only one side of a [[SpatialJoin]] into a Filter on that
 * side; predicates that fit neither side alone stay in the join condition.
 */
object PushPredicateThroughSpatialJoin extends Rule[LogicalPlan] with PredicateHelper {

  /**
   * Partitions `condition` into (left-only, right-only, remaining) predicates, judged by
   * whether each predicate's references are a subset of that side's output.
   */
  private def split(
      condition: Seq[Expression],
      left: LogicalPlan,
      right: LogicalPlan): (Seq[Expression], Seq[Expression], Seq[Expression]) = {
    val (leftOnly, notLeft) = condition.partition(_.references subsetOf left.outputSet)
    val (rightOnly, remaining) = notLeft.partition(_.references subsetOf right.outputSet)
    (leftOnly, rightOnly, remaining)
  }

  /** Wraps `child` in a Filter over the conjunction of `predicates`, or returns it as is. */
  private def filterIfAny(predicates: Seq[Expression], child: LogicalPlan): LogicalPlan =
    predicates.reduceLeftOption(And) match {
      case Some(conjunction) => Filter(conjunction, child)
      case None => child
    }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // push the where condition down into join filter
    case Filter(filterCondition, SpatialJoin(left, right, joinType, joinCondition)) =>
      val (leftPreds, rightPreds, sharedPreds) =
        split(splitConjunctivePredicates(filterCondition), left, right)
      val mergedJoinCond = (sharedPreds ++ joinCondition).reduceLeftOption(And)
      SpatialJoin(
        filterIfAny(leftPreds, left),
        filterIfAny(rightPreds, right),
        joinType,
        mergedJoinCond)

    // push down the join filter into sub query scanning if applicable
    case SpatialJoin(left, right, joinType, joinCondition) =>
      val conjuncts = joinCondition.toSeq.flatMap(splitConjunctivePredicates)
      val (leftPreds, rightPreds, sharedPreds) = split(conjuncts, left, right)
      SpatialJoin(
        filterIfAny(leftPreds, left),
        filterIfAny(rightPreds, right),
        joinType,
        sharedPreds.reduceLeftOption(And))
  }
}
Example 8
Source File: PredicateUtil.scala From Simba with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.simba.util

import org.apache.spark.sql.catalyst.expressions.{Expression, And, Or}

/** Helpers for normalising boolean expressions and splitting them into clauses. */
object PredicateUtil {

  /**
   * Rewrites `condition` into disjunctive normal form by distributing And over Or.
   * Expressions other than And/Or are returned unchanged.
   */
  def toDNF(condition: Expression): Expression = condition match {
    case Or(left, right) => Or(toDNF(left), toDNF(right))
    case And(left, right) =>
      val dnfLeft = toDNF(left)
      val dnfRight = toDNF(right)
      // Distribute over the left disjunction first; the right one is only expanded when
      // the left side is not an Or. The result is normalised again after each expansion.
      (dnfLeft, dnfRight) match {
        case (Or(l, r), _) => toDNF(Or(And(l, dnfRight), And(r, dnfRight)))
        case (_, Or(l, r)) => toDNF(Or(And(dnfLeft, l), And(dnfLeft, r)))
        case _ => And(dnfLeft, dnfRight)
      }
    case exp => exp
  }

  /**
   * Rewrites `condition` into conjunctive normal form by distributing Or over And.
   * Expressions other than And/Or are returned unchanged.
   */
  def toCNF(condition: Expression): Expression = condition match {
    case And(left, right) => And(toCNF(left), toCNF(right))
    case Or(left, right) =>
      val cnfLeft = toCNF(left)
      val cnfRight = toCNF(right)
      // Mirror image of toDNF: the left conjunction takes priority over the right one.
      (cnfLeft, cnfRight) match {
        case (And(l, r), _) => toCNF(And(Or(l, cnfRight), Or(r, cnfRight)))
        case (_, And(l, r)) => toCNF(And(Or(cnfLeft, l), Or(cnfLeft, r)))
        case _ => Or(cnfLeft, cnfRight)
      }
    case exp => exp
  }

  /**
   * Splits an expression on its top-level Ors. A left-nested And is re-associated to the
   * right before recursing; any other expression becomes a single clause.
   */
  def dnfExtract(expression: Expression): Seq[Expression] = expression match {
    case Or(left, right) => dnfExtract(left) ++ dnfExtract(right)
    case And(And(l2, r2), right) => dnfExtract(And(l2, And(r2, right)))
    case other => other :: Nil
  }

  /**
   * Splits an expression on its top-level Ands. A left-nested Or is re-associated to the
   * right before recursing; any other expression becomes a single clause.
   */
  def cnfExtract(expression: Expression): Seq[Expression] = expression match {
    case And(left, right) => cnfExtract(left) ++ cnfExtract(right)
    case Or(Or(l2, r2), right) => cnfExtract(Or(l2, Or(r2, right)))
    case other => other :: Nil
  }

  /** Normalises `condition` to DNF and returns its clauses. */
  def splitDNFPredicates(condition: Expression): Seq[Expression] = dnfExtract(toDNF(condition))

  /** Normalises `condition` to CNF and returns its clauses. */
  def splitCNFPredicates(condition: Expression): Seq[Expression] = cnfExtract(toCNF(condition))

  /** Flattens nested Ands into the list of their operands, without normalising first. */
  def splitConjunctivePredicates(condition: Expression): Seq[Expression] = condition match {
    case And(cond1, cond2) =>
      splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
    case other => other :: Nil
  }

  /** Flattens nested Ors into the list of their operands, without normalising first. */
  def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = condition match {
    case Or(cond1, cond2) =>
      splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2)
    case other => other :: Nil
  }
}
Example 9
Source File: ParameterBinderSuite.scala From spark-sql-server with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.sql.SQLException

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder
import org.apache.spark.sql.server.service.ParamBinder
import org.apache.spark.sql.types._

/** Checks that ParamBinder replaces placeholders with literals and reports unbound ones. */
class ParameterBinderSuite extends PlanTest {

  test("bind parameters") {
    val colA = 'a.int
    val colB = 'b.int
    val relation = LocalRelation(colA, colB)

    // Binds with no parameters and returns the message of the expected SQLException.
    def unboundMessage(plan: Filter): String =
      intercept[SQLException] {
        ParamBinder.bind(plan, Map.empty)
      }.getMessage

    // A single placeholder with id 1.
    val eighteen = Literal(18, IntegerType)
    val singleParamPlan = Filter(EqualTo(colA, ParameterPlaceHolder(1)), relation)
    comparePlans(
      Filter(EqualTo(colA, eighteen), relation),
      ParamBinder.bind(singleParamPlan, Map(1 -> eighteen)))

    // A single placeholder with id 300.
    val fortyTwo = Literal(42, IntegerType)
    val largeIdPlan = Filter(EqualTo(colA, ParameterPlaceHolder(300)), relation)
    comparePlans(
      Filter(EqualTo(colA, fortyTwo), relation),
      ParamBinder.bind(largeIdPlan, Map(300 -> fortyTwo)))

    // Two placeholders combined with And.
    val minusOne = Literal(-1, IntegerType)
    val fortyEight = Literal(48, IntegerType)
    val twoParamPlan = Filter(
      And(
        EqualTo(colA, ParameterPlaceHolder(1)),
        EqualTo(colB, ParameterPlaceHolder(2))
      ), relation)
    val boundTwoParamPlan = Filter(
      And(
        EqualTo(colA, minusOne),
        EqualTo(colB, fortyEight)
      ), relation)
    comparePlans(
      boundTwoParamPlan,
      ParamBinder.bind(twoParamPlan, Map(1 -> minusOne, 2 -> fortyEight)))

    // Binding without values must name every placeholder left unresolved.
    assert(unboundMessage(singleParamPlan) == "Unresolved parameters found: $1")
    assert(unboundMessage(largeIdPlan) == "Unresolved parameters found: $300")
    assert(unboundMessage(twoParamPlan) == "Unresolved parameters found: $1, $2")
  }
}