Example 1
Source File: MergeCommand.scala    From spark-acid   with Apache License 2.0 (5 votes)
package com.qubole.spark.datasources.hiveacid.sql.catalyst.plans.command

import com.qubole.spark.hiveacid.HiveAcidErrors
import com.qubole.spark.hiveacid.datasource.HiveAcidRelation
import com.qubole.spark.hiveacid.merge.{MergeCondition, MergeWhenClause, MergeWhenNotInsert}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.{Row, SparkSession, SqlUtils}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation

/**
 * Runnable command that performs a MERGE of `sourceTable` into `targetTable` by delegating to
 * `HiveAcidRelation.merge`.
 *
 * NOTE(review): this snippet appears truncated (closing braces and the value returned by `run`
 * are not visible); the comments describe only the visible code.
 *
 * @param targetTable    plan of the merge target; `run` only accepts a `LogicalRelation` over a
 *                       `HiveAcidRelation`, optionally wrapped in a `SubqueryAlias`
 * @param sourceTable    plan converted to a DataFrame and handed to the merge as its source
 * @param matched        WHEN MATCHED clauses, passed through to the merge
 * @param notMatched     optional WHEN NOT MATCHED clause; must be a `MergeWhenNotInsert`
 * @param mergeCondition condition whose `expression` is passed to the merge
 * @param sourceAlias    optional source alias, passed through to the merge
 * @param targetAlias    optional target alias, passed through to the merge
 */
case class MergeCommand(targetTable: LogicalPlan,
                        sourceTable: LogicalPlan,
                        matched: Seq[MergeWhenClause],
                        notMatched: Option[MergeWhenClause],
                        mergeCondition: MergeCondition,
                        sourceAlias: Option[AliasIdentifier],
                        targetAlias: Option[AliasIdentifier])
  extends RunnableCommand {

  // Target first, then source: `run` inspects `children.head` to find the target relation.
  override def children: Seq[LogicalPlan] = Seq(targetTable, sourceTable)
  // The command exposes no output attributes.
  override def output: Seq[Attribute] = Seq.empty
  // Considered resolved as soon as both child plans are resolved.
  override lazy val resolved: Boolean = childrenResolved
  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Narrow the optional WHEN NOT MATCHED clause to an INSERT clause; any other clause type is
    // rejected with a merge validation error.
    val insertClause: Option[MergeWhenNotInsert] = notMatched match {
      case Some(i: MergeWhenNotInsert) => Some(i)
      case None => None
      case _ => throw HiveAcidErrors.mergeValidationError("WHEN NOT Clause has to be INSERT CLAUSE")

    // Both accepted shapes (bare relation, aliased relation) issue the same merge call; any
    // other target is reported as a non-ACID table.
    children.head match {
      case LogicalRelation(relation: HiveAcidRelation, _, _ , _) =>
        relation.merge(SqlUtils.logicalPlanToDataFrame(sparkSession, sourceTable),
          mergeCondition.expression, matched, insertClause, sourceAlias, targetAlias)
      case SubqueryAlias(_, LogicalRelation(relation: HiveAcidRelation, _, _, _)) =>
        relation.merge(SqlUtils.logicalPlanToDataFrame(sparkSession, sourceTable),
          mergeCondition.expression, matched, insertClause, sourceAlias, targetAlias)
      case _ => throw HiveAcidErrors.tableNotAcidException(targetTable.toString())

Example 2
Source File: ExistingRDD.scala    From BigDatalog   with Apache License 2.0 (5 votes)
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SQLContext}

/**
 * Helpers for converting RDDs of user objects into RDDs of `InternalRow`.
 *
 * NOTE(review): the snippet is garbled -- the definition of `converters` and the mapping over
 * `iterator` are missing, so only the visible steps are documented.
 */
object RDDConversions {
  /**
   * Converts an RDD of `Product`s into an RDD of `InternalRow`, one partition at a time.
   *
   * @param data        the products to convert
   * @param outputTypes one data type per column; its length fixes the row width
   */
  def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
    data.mapPartitions { iterator =>
      val numColumns = outputTypes.length
      // One mutable row is allocated per partition (presumably reused for every element --
      // TODO confirm against the complete source).
      val mutableRow = new GenericMutableRow(numColumns)
      val converters = { r =>
        // Fill the row column by column, converting each product element with the converter
        // at the same index.
        var i = 0
        while (i < numColumns) {
          mutableRow(i) = converters(i)(r.productElement(i))
          i += 1


/**
 * Leaf physical node whose execution simply returns an already-built RDD of `InternalRow`.
 *
 * NOTE(review): closing braces are missing from this snippet.
 *
 * @param output            attributes produced by this node
 * @param rdd               the rows returned as-is by `doExecute`
 * @param nodeName          name shown in the plan string
 * @param metadata          key/value entries appended to the plan string (empty by default)
 * @param outputsUnsafeRows overrides the node's `outputsUnsafeRows` flag (false by default)
 */
case class PhysicalRDD(
    output: Seq[Attribute],
    rdd: RDD[InternalRow],
    override val nodeName: String,
    override val metadata: Map[String, String] = Map.empty,
    override val outputsUnsafeRows: Boolean = false)
  extends LeafNode {

  // No computation of its own: the wrapped RDD is the result.
  protected override def doExecute(): RDD[InternalRow] = rdd

  // Renders "Scan <nodeName>[<output attrs>] <key: value, ...>" with the metadata entries in
  // sorted order.
  override def simpleString: String = {
    val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value"
    s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}"

/**
 * Companion holding the metadata key names and a factory for data-source scans.
 *
 * NOTE(review): closing braces are missing from this snippet.
 */
private[sql] object PhysicalRDD {
  // Metadata keys
  val INPUT_PATHS = "InputPaths"
  val PUSHED_FILTERS = "PushedFilters"

  /**
   * Builds a `PhysicalRDD` for a data-source relation, using `relation.toString` as the node
   * name.
   *
   * @param output   attributes produced by the scan
   * @param rdd      rows produced by the scan
   * @param relation source relation; decides whether the node is flagged as emitting unsafe rows
   * @param metadata extra entries for the plan string (empty by default)
   */
  def createFromDataSource(
      output: Seq[Attribute],
      rdd: RDD[InternalRow],
      relation: BaseRelation,
      metadata: Map[String, String] = Map.empty): PhysicalRDD = {
    // All HadoopFsRelations output UnsafeRows
    val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation]
    PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows)
Example 3
Source File: ExtraStrategiesSuite.scala    From BigDatalog   with Apache License 2.0 (5 votes)

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String

/**
 * Test-only physical operator with no children.
 *
 * NOTE(review): the end of `doExecute` is missing from this snippet; only the construction of
 * a single-column row holding the literal "so fast" is visible.
 *
 * @param output attributes this operator claims to produce
 */
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    // `Literal(...).value` yields the literal's internal value, which becomes the row's only
    // field.
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))

  override def children: Seq[SparkPlan] = Nil

/**
 * Test strategy that plans a matching single-expression `Project` as a `FastOperator` and
 * declines (returns `Nil`) for every other plan.
 *
 * NOTE(review): the guard on the `Project` case is garbled (`if == "a"` has lost its left-hand
 * side); presumably it compared the projected expression's name with "a" -- confirm against
 * the original source.
 */
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil

/**
 * Checks that a strategy registered through `sqlContext.experimental.extraStrategies` takes
 * part in planning.
 *
 * NOTE(review): both `checkAnswer` calls are garbled in this snippet (their first arguments
 * are missing); only the expected rows remain visible.
 */
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
        Row("so fast"))

      checkAnswer("a", "b"),
        Row("so slow", 1))
    } finally {
      // Always unregister the strategy so that it does not leak into other tests sharing the
      // context.
      sqlContext.experimental.extraStrategies = Nil
Example 4
Source File: GenomicIntervalStrategy.scala    From bdg-sequila   with Apache License 2.0 (5 votes)
package org.biodatageeks.sequila.utvf

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{DataFrame, GenomicInterval, SparkSession, Strategy}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Row carrying one interval: the contig it lies on plus its two integer bounds.
 *
 * @param contigName name of the contig
 * @param start      first bound of the interval
 * @param end        second bound of the interval
 */
case class GIntervalRow(
    contigName: String,
    start: Int,
    end: Int)
/**
 * Planner strategy that turns a `GenomicInterval` logical node into a `GenomicIntervalPlan`;
 * every other plan is declined by returning `Nil`.
 *
 * NOTE(review): closing braces are missing from this snippet.
 *
 * @param spark session handed to the physical plan
 */
class GenomicIntervalStrategy( spark: SparkSession) extends Strategy with Serializable  {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

    // The interval fields are packed into a GIntervalRow for the physical plan.
    case GenomicInterval(contigName, start, end,output) => GenomicIntervalPlan(plan,spark,GIntervalRow(contigName,start,end),output) :: Nil
    case _ => Nil


/**
 * Childless physical plan built around a single `GIntervalRow`.
 *
 * NOTE(review): the body of `doExecute` is truncated; only the creation of a one-element
 * dataset from `interval` and of an `UnsafeProjection` over `schema` is visible, not how they
 * are combined into the returned RDD.
 *
 * @param plan     the logical plan this node was planned from
 * @param spark    session used to create the dataset
 * @param interval the interval to emit
 * @param output   attributes produced by this node
 */
case class GenomicIntervalPlan(plan: LogicalPlan, spark: SparkSession,interval:GIntervalRow, output: Seq[Attribute]) extends SparkPlan with Serializable {
  def doExecute(): org.apache.spark.rdd.RDD[InternalRow] = {
    // Brings the implicit encoders needed by `createDataset` into scope.
    import spark.implicits._

    lazy val genomicInterval = spark.createDataset(Seq(interval))
        val proj =  UnsafeProjection.create(schema)
  def children: Seq[SparkPlan] = Nil
Example 5
Source File: SeQuiLaAnalyzer.scala    From bdg-sequila   with Apache License 2.0 (5 votes)
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.ResolveTableValuedFunctionsSeq
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

import scala.util.Random

/**
 * Analyzer whose `batches` are overridden so that `ResolveTableValuedFunctionsSeq` runs first
 * in the "Resolution" batch and an extra "SeQuiLa" batch applies `sequilaOptmazationRules`
 * once after "Post-Hoc Resolution".
 *
 * NOTE(review): several batches in this snippet have lost their rule lists and closing
 * parentheses ("Hints", "Simple Sanity Check", "View", "Nondeterministic", "UDF",
 * "FixNullability", "Subquery", "Cleanup"), and one entry of the "Resolution" list is blank.
 *
 * @param catalog session catalog passed to the base `Analyzer`
 * @param conf    SQL configuration; also supplies `optimizerMaxIterations` to the base class
 */
class SeQuiLaAnalyzer(catalog: SessionCatalog, conf: SQLConf) extends Analyzer(catalog, conf, conf.optimizerMaxIterations){
  //override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(ResolveTableValuedFunctionsSeq)

  //  override lazy val batches: Seq[Batch] = Seq(
  //    Batch("Custeom", fixedPoint, ResolveTableValuedFunctionsSeq),
  //    Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf),
  //      ResolveHints.RemoveAllHints))

  // Rules run by the "SeQuiLa" batch. `batches` is a lazy val, so this must be assigned before
  // `batches` is first accessed for the rules to be picked up.
  var sequilaOptmazationRules: Seq[Rule[LogicalPlan]] = Nil

  override lazy val batches: Seq[Batch] = Seq(
    Batch("Hints", fixedPoint,
      new ResolveHints.ResolveBroadcastHints(conf),
    Batch("Simple Sanity Check", Once,
    Batch("Substitution", fixedPoint,
      new SubstituteUnresolvedOrdinals(conf)),
    Batch("Resolution", fixedPoint,
      ResolveTableValuedFunctionsSeq ::
      ResolveRelations ::
        ResolveReferences ::
        ResolveCreateNamedStruct ::
        ResolveDeserializer ::
        ResolveNewInstance ::
        ResolveUpCast ::
        ResolveGroupingAnalytics ::
        ResolvePivot ::
        ResolveOrdinalInOrderByAndGroupBy ::
        ResolveAggAliasInGroupBy ::
        ResolveMissingReferences ::
        ExtractGenerator ::
        ResolveGenerate ::
        ResolveFunctions ::
        ResolveAliases ::
        ResolveSubquery ::
        ResolveSubqueryColumnAliases ::
        ResolveWindowOrder ::
        ResolveWindowFrame ::
        ResolveNaturalAndUsingJoin ::

        ExtractWindowExpressions ::
        GlobalAggregates ::
        ResolveAggregateFunctions ::
        TimeWindowing ::
        ResolveInlineTables(conf) ::
        ResolveTimeZone(conf) ::
        TypeCoercion.typeCoercionRules(conf) ++
          extendedResolutionRules : _*),
    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
    Batch("SeQuiLa", Once,sequilaOptmazationRules: _*), //SeQuilaOptimization rules
    Batch("View", Once,
    Batch("Nondeterministic", Once,
    Batch("UDF", Once,
    Batch("FixNullability", Once,
    Batch("Subquery", Once,
    Batch("Cleanup", fixedPoint,

Example 6
Source File: PileupStrategy.scala    From bdg-sequila   with Apache License 2.0 (5 votes)
package org.biodatageeks.sequila.pileup

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{PileupTemplate, SparkSession, Strategy}
import org.biodatageeks.sequila.datasources.BAM.BDGAlignFileReaderWriter
import org.biodatageeks.sequila.datasources.InputDataType
import org.biodatageeks.sequila.inputformats.BDGAlignInputFormat
import org.biodatageeks.sequila.utils.TableFuncs
import org.seqdoop.hadoop_bam.{BAMBDGInputFormat, CRAMBDGInputFormat}

import scala.reflect.ClassTag

/**
 * Planner strategy for `PileupTemplate` nodes: looks up the table's provider and plans a
 * `PileupPlan` specialised for the BAM or CRAM input format.
 *
 * NOTE(review): closing braces are missing from this snippet.
 *
 * @param spark session used for the metadata lookup and handed to the physical plan
 */
class PileupStrategy (spark:SparkSession) extends Strategy with Serializable {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
    plan match {
      case PileupTemplate(tableName, sampleId, refPath, output) =>
        val inputFormat = TableFuncs.getTableMetadata(spark, tableName).provider
        inputFormat match {
          case Some(f) =>
            if (f == InputDataType.BAMInputDataType)
              PileupPlan[BAMBDGInputFormat](plan, spark, tableName, sampleId, refPath, output) :: Nil
            else if (f == InputDataType.CRAMInputDataType)
              PileupPlan[CRAMBDGInputFormat](plan, spark, tableName, sampleId, refPath, output) :: Nil
            // NOTE(review): a provider that is neither BAM nor CRAM yields Nil (no plan), while
            // a missing provider throws below -- confirm this asymmetry is intended.
            else Nil
          case None => throw new RuntimeException("Only BAM and CRAM file formats are supported in pileup function.")
      case _ => Nil

/**
 * Childless physical plan that delegates execution to `Pileup.handlePileup`.
 *
 * NOTE(review): the parameter list is truncated in this snippet -- `tableName` and `sampleId`
 * are used in `doExecute` and passed by `PileupStrategy`, but their declarations are missing.
 *
 * @param plan    the logical plan this node was planned from
 * @param spark   session passed to `Pileup`
 * @param refPath reference path forwarded to `handlePileup`
 * @param output  attributes produced by this node, also forwarded to `handlePileup`
 * @tparam T alignment input format, made available through the implicit `ClassTag`
 */
case class PileupPlan [T<:BDGAlignInputFormat](plan:LogicalPlan, spark:SparkSession,
                                               refPath: String,
                                               output:Seq[Attribute])(implicit c: ClassTag[T])
  extends SparkPlan with Serializable  with BDGAlignFileReaderWriter [T]{

  override def children: Seq[SparkPlan] = Nil

  override protected def doExecute(): RDD[InternalRow] = {
   new Pileup(spark).handlePileup(tableName, sampleId, refPath, output)

Example 7
Source File: NCListsJoinStrategy.scala    From bdg-sequila   with Apache License 2.0 (5 votes)
package org.biodatageeks.sequila.rangejoins.NCList

import org.biodatageeks.sequila.rangejoins.common.{ExtractRangeJoinKeys, ExtractRangeJoinKeysWithEquality}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{SparkSession, Strategy}
import org.biodatageeks.sequila.rangejoins.methods.NCList.NCListsJoinChromosome

/**
 * Planner strategy that maps range-join patterns onto NCList-based physical joins:
 * `ExtractRangeJoinKeys` becomes `NCListsJoin` and `ExtractRangeJoinKeysWithEquality` becomes
 * `NCListsJoinChromosome`.
 *
 * NOTE(review): the fallback branch (`case _ =>`) has lost its body in this snippet. The
 * extracted `joinType` is not used by either visible branch.
 *
 * @param spark session handed to the physical join operators
 */
class NCListsJoinStrategy(spark: SparkSession) extends Strategy with Serializable with  PredicateHelper {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // `planLater` defers planning of both join inputs to the planner.
    case ExtractRangeJoinKeys(joinType, rangeJoinKeys, left, right) =>
      NCListsJoin(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case ExtractRangeJoinKeysWithEquality(joinType, rangeJoinKeys, left, right) =>
      NCListsJoinChromosome(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case _ =>
Example 8
Source File: IntervalTreeJoinOptim.scala    From bdg-sequila   with Apache License 2.0 (5 votes)
package org.biodatageeks.sequila.rangejoins.IntervalTree

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, _}
import org.apache.spark.sql.internal.SQLConf

/**
 * Binary physical operator joining two inputs on overlapping intervals through
 * `IntervalTreeJoinOptimImpl.overlapJoin`.
 *
 * NOTE(review): this snippet is heavily garbled -- the `map` calls that build `v1kv`/`v2kv`,
 * the start of the size computations, and the code turning `v3` into the result are missing.
 * The comments cover only what is visible.
 *
 * @param left            build-side plan
 * @param right           streamed-side plan
 * @param condition       key expressions; indices 0-1 are evaluated against `left`, and
 *                        indices 2-3 are used as the streamed-side keys
 * @param context         session providing the `SparkContext` for the join
 * @param leftLogicalPlan logical plan of the left side, consulted for size statistics
 * @param righLogicalPlan logical plan of the right side, consulted for size statistics
 */
case class IntervalTreeJoinOptim(left: SparkPlan,
                                 right: SparkPlan,
                                 condition: Seq[Expression],
                                 context: SparkSession,leftLogicalPlan: LogicalPlan, righLogicalPlan: LogicalPlan) extends BinaryExecNode {
  // Output is the left attributes followed by the right attributes.
  def output = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  // Requires `condition` to hold at least four expressions.
  lazy val (buildKeys, streamedKeys) = (List(condition(0), condition(1)),
    List(condition(2), condition(3)))

  @transient lazy val buildKeyGenerator = new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator = new InterpretedProjection(streamedKeys,

  protected override def doExecute(): RDD[InternalRow] = {
    val v1 = left.execute()
    val v1kv = => {
      // Project the two key columns and read them as Int interval bounds, keeping the row.
      val v1Key = buildKeyGenerator(x)

      (new IntervalWithRow[Int](v1Key.getInt(0), v1Key.getInt(1),
        x) )

    val v2 = right.execute()
    val v2kv = => {
      val v2Key = streamKeyGenerator(x)
      (new IntervalWithRow[Int](v2Key.getInt(0), v2Key.getInt(1),
        x) )

    // NOTE(review): `conf` is not referenced anywhere in the visible code.
    val conf = new SQLConf()
    val v1Size =
      .sizeInBytes >0) leftLogicalPlan.stats.sizeInBytes.toLong

    val v2Size = if(righLogicalPlan
      .sizeInBytes >0) righLogicalPlan.stats.sizeInBytes.toLong
    // The side with the smaller estimated size is passed first to overlapJoin, together with
    // its row count (note that `count()` is an RDD action).
    if ( v1Size <= v2Size ) {
      val v3 = IntervalTreeJoinOptimImpl.overlapJoin(context.sparkContext, v1kv, v2kv,v1.count())
       p => {
         val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)>joiner.join(r._1.asInstanceOf[UnsafeRow],r._2.asInstanceOf[UnsafeRow]))


    else {
      val v3 = IntervalTreeJoinOptimImpl.overlapJoin(context.sparkContext, v2kv, v1kv, v2.count())
        p => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)


Example 9
Source File: IntervalTreeJoinStrategy.scala    From bdg-sequila   with Apache License 2.0 (5 votes)
package org.biodatageeks.sequila.rangejoins.genApp

import org.biodatageeks.sequila.rangejoins.common.{ExtractRangeJoinKeys, ExtractRangeJoinKeysWithEquality}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{SparkSession, Strategy}
import org.biodatageeks.sequila.rangejoins.methods.genApp.IntervalTreeJoinChromosome

/**
 * Planner strategy that maps range-join patterns onto interval-tree physical joins:
 * `ExtractRangeJoinKeys` becomes `IntervalTreeJoin` and `ExtractRangeJoinKeysWithEquality`
 * becomes `IntervalTreeJoinChromosome`.
 *
 * NOTE(review): the fallback branch (`case _ =>`) has lost its body in this snippet. The
 * extracted `joinType` is not used by either visible branch.
 *
 * @param spark session handed to the physical join operators
 */
class IntervalTreeJoinStrategy(spark: SparkSession) extends Strategy with Serializable with  PredicateHelper {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // `planLater` defers planning of both join inputs to the planner.
    case ExtractRangeJoinKeys(joinType, rangeJoinKeys, left, right) =>
      IntervalTreeJoin(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case ExtractRangeJoinKeysWithEquality(joinType, rangeJoinKeys, left, right) =>
      IntervalTreeJoinChromosome(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case _ =>
Example 10
Source File: SqlUtils.scala    From spark-acid   with Apache License 2.0 (5 votes)
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types.StructType

/**
 * Utilities for building DataFrames from logical plans / RDDs and for resolving expressions
 * against plans.
 *
 * NOTE(review): this snippet is truncated -- closing braces, match results, the body of
 * `hasSparkStopped` and most arguments of `LogicalRDD` are missing. Only visible behaviour is
 * documented.
 */
object SqlUtils {
  /** Wraps `plan` into a DataFrame via `Dataset.ofRows`. */
  def convertToDF(sparkSession: SparkSession, plan : LogicalPlan): DataFrame = {
    Dataset.ofRows(sparkSession, plan)

  /** Single-plan convenience overload; forwards to the `Seq[LogicalPlan]` variant. */
  def resolveReferences(sparkSession: SparkSession,
                        expr: Expression,
                        planContaining: LogicalPlan, failIfUnresolved: Boolean,
                        exprName: Option[String] = None): Expression = {
    resolveReferences(sparkSession, expr, Seq(planContaining), failIfUnresolved, exprName)

  /**
   * Resolves `expr` against the given plans by wrapping both in a `FakeLogicalPlan` and
   * running the session analyzer over it.
   *
   * @param expr             expression to resolve
   * @param planContaining   plans used as children of the fake plan
   * @param failIfUnresolved when true, references that remain unresolved are reported
   * @param exprName         optional name included in the failure message
   */
  def resolveReferences(sparkSession: SparkSession,
                        expr: Expression,
                        planContaining: Seq[LogicalPlan],
                        failIfUnresolved: Boolean,
                        exprName: Option[String]): Expression = {
    val newPlan = FakeLogicalPlan(expr, planContaining)
    val resolvedExpr = sparkSession.sessionState.analyzer.execute(newPlan) match {
      case FakeLogicalPlan(resolvedExpr: Expression, _) =>
        // Return even if it did not successfully resolve
      case _ =>
      // This is unexpected
    if (failIfUnresolved) {
      // Visit every reference that is still unresolved and build a message naming it.
      resolvedExpr.flatMap(_.references).filter(!_.resolved).foreach {
        attr => {
          val failedMsg = exprName match {
            case Some(name) => s"${attr.sql} resolution in $name given these columns: "+
            case _ => s"${attr.sql} resolution failed given these columns: "+

  def hasSparkStopped(sparkSession: SparkSession): Boolean = {

  /**
   * Builds a DataFrame over `rdd` through a `LogicalRDD` marked as non-streaming.
   * NOTE(review): how `encoder` and `attributes` are used is not visible here.
   */
  def createDataFrameUsingAttributes(sparkSession: SparkSession,
                                     rdd: RDD[Row],
                                     schema: StructType,
                                     attributes: Seq[Attribute]): DataFrame = {
    val encoder = RowEncoder(schema)
    val catalystRows =
    val logicalPlan = LogicalRDD(
      isStreaming = false)(sparkSession)
    Dataset.ofRows(sparkSession, logicalPlan)

  /** Creates (without throwing) an `AnalysisException` carrying `cause` as its message. */
  def analysisException(cause: String): Throwable = {
    new AnalysisException(cause)

/**
 * Logical plan that carries a single expression alongside arbitrary child plans.
 *
 * NOTE(review): the closing brace is missing from this snippet.
 *
 * @param expr     the expression carried by the plan
 * @param children the plans whose outputs this node exposes
 */
case class FakeLogicalPlan(expr: Expression, children: Seq[LogicalPlan])
  extends LogicalPlan {
  // Output is the concatenation of every child's output, in child order.
  override def output: Seq[Attribute] = children.foldLeft(Seq[Attribute]())((out, child) => out ++ child.output)
Example 11
Source File: MergePlan.scala    From spark-acid   with Apache License 2.0 (5 votes)
package org.apache.spark.sql.catalyst.parser.plans.logical

import com.qubole.spark.hiveacid.merge.{MergeWhenClause}
import org.apache.spark.sql.{SparkSession, SqlUtils}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}

/**
 * Logical command node describing a MERGE statement.
 *
 * NOTE(review): the closing brace is missing from this snippet.
 *
 * @param sourcePlan plan providing the rows to merge
 * @param targetPlan plan of the table being merged into
 * @param condition  merge condition
 * @param matched    WHEN MATCHED clauses
 * @param notMatched optional WHEN NOT MATCHED clause
 */
case class MergePlan(sourcePlan: LogicalPlan,
                     targetPlan: LogicalPlan,
                     condition: Expression,
                     matched: Seq[MergeWhenClause],
                     notMatched: Option[MergeWhenClause]) extends Command {
  // Child order is source first, then target.
  override def children: Seq[LogicalPlan] = Seq(sourcePlan, targetPlan)
  // The command exposes no output attributes.
  override def output: Seq[Attribute] = Seq.empty

/**
 * Companion providing resolution of a `MergePlan`.
 *
 * NOTE(review): the snippet is truncated -- the expression that `resolvedNotMatched` maps over
 * and the construction of the returned plan are missing.
 */
object MergePlan {
  /**
   * Validates the WHEN clauses, then resolves the merge condition and the clauses against the
   * plan's children.
   */
  def resolve(sparkSession: SparkSession, mergePlan: MergePlan): MergePlan = {
    // Matched and not-matched clauses are validated together.
    MergeWhenClause.validate(mergePlan.matched ++ mergePlan.notMatched)
    // `true` = failIfUnresolved; `None` = no expression name for the error message.
    val resolvedCondition = SqlUtils.resolveReferences(sparkSession, mergePlan.condition,
      mergePlan.children, true, None)
    val resolvedMatched = MergeWhenClause.resolve(sparkSession, mergePlan, mergePlan.matched)
    val resolvedNotMatched = {
      x => x.resolve(sparkSession, mergePlan)

Example 12
Source File: DeleteCommand.scala    From spark-acid   with Apache License 2.0 (5 votes)
package com.qubole.spark.datasources.hiveacid.sql.catalyst.plans.command

import com.qubole.spark.hiveacid.HiveAcidErrors
import com.qubole.spark.hiveacid.datasource.HiveAcidRelation
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation

/**
 * Runnable command that deletes the rows matching `condition` from a Hive ACID table.
 *
 * NOTE(review): closing braces and the value returned by `run` are missing from this snippet.
 *
 * @param table     plan of the table to delete from; must be a `LogicalRelation` over a
 *                  `HiveAcidRelation`
 * @param condition predicate selecting the rows to delete
 */
case class DeleteCommand(
    table: LogicalPlan,
    condition: Expression)
  extends RunnableCommand {

  // NOTE(review): an earlier comment here said `table` should NOT be in children because it
  // sometimes must not be transformed, yet the code does return `table` as the only child and
  // `run` depends on that -- confirm which is intended.
  override def children: Seq[LogicalPlan] = Seq(table)
  // The command exposes no output attributes.
  override def output: Seq[Attribute] = Seq.empty
  // Considered resolved as soon as the child plan is resolved.
  override lazy val resolved: Boolean = childrenResolved
  override def run(sparkSession: SparkSession): Seq[Row] = {
    if (children.size != 1) {
      throw new IllegalArgumentException("DELETE command should specify exactly one table, whereas this has: "
        + children.size)
    // Only a HiveAcidRelation supports delete; anything else is reported as a non-ACID table.
    children(0) match {
      case LogicalRelation(relation: HiveAcidRelation, _, _ , _) => {
        relation.delete(new Column(condition))
      case _ => throw HiveAcidErrors.tableNotAcidException(table.toString())
Example 13
Source File: QueryExecution.scala    From BigDatalog   with Apache License 2.0 (5 votes)
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

  // Executes `executedPlan` on first access and caches the resulting RDD (lazy val).
  lazy val toRdd: RDD[InternalRow] = executedPlan.execute()

  /**
   * Renders the result of `f` as a string, falling back to the error's own string form when
   * evaluating `f` fails.
   *
   * Only non-fatal errors are converted. Fatal ones (e.g. `OutOfMemoryError`,
   * `InterruptedException`, `ControlThrowable`) are left to propagate instead of being
   * swallowed into a string.
   *
   * @param f by-name value to evaluate and render
   * @return `f.toString`, or `e.toString` for a non-fatal `e` thrown while evaluating `f`
   */
  protected def stringOrError[A](f: => A): String =
    try f.toString catch { case scala.util.control.NonFatal(e) => e.toString }

  // Short rendering that starts with the "== Physical Plan ==" header.
  // NOTE(review): the rest of the string literal is missing from this snippet.
  def simpleString: String = {
    s"""== Physical Plan ==

  // Full rendering with parsed, analyzed, optimized and physical plan sections.
  // NOTE(review): this snippet is garbled -- the `output` helper has lost the collection it
  // maps over and the section bodies are missing; only the section headers remain.
  override def toString: String = {
    def output = => s"${}: ${o.dataType.simpleString}").mkString(", ")

    s"""== Parsed Logical Plan ==
       |== Analyzed Logical Plan ==
       |== Optimized Logical Plan ==
       |== Physical Plan ==
Example 14
Source File: UpdateCommand.scala    From spark-acid   with Apache License 2.0 (5 votes)
package com.qubole.spark.datasources.hiveacid.sql.catalyst.plans.command

import com.qubole.spark.hiveacid.HiveAcidErrors
import com.qubole.spark.hiveacid.datasource.HiveAcidRelation
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation

/**
 * Runnable command that updates columns of a Hive ACID table.
 *
 * NOTE(review): closing braces and the value returned by `run` are missing from this snippet,
 * and the line building `updateFilterColumn` is garbled.
 *
 * @param table          plan of the table to update; must be a `LogicalRelation` over a
 *                       `HiveAcidRelation`
 * @param setExpressions column name -> new value expression
 * @param condition      optional predicate selecting the rows to update
 */
case class UpdateCommand(
    table: LogicalPlan,
    setExpressions: Map[String, Expression],
    condition: Option[Expression])
  extends RunnableCommand {

  override def children: Seq[LogicalPlan] = Seq(table)
  // The command exposes no output attributes.
  override def output: Seq[Attribute] = Seq.empty
  // Considered resolved as soon as the child plan is resolved.
  override lazy val resolved: Boolean = childrenResolved

  override def run(sparkSession: SparkSession): Seq[Row] = {
    if (children.size != 1) {
      throw new IllegalArgumentException("UPDATE command should have one table to update, whereas this has: "
        + children.size)
    children(0) match {
      case LogicalRelation(relation: HiveAcidRelation, _, _ , _) => {
        // Wrap each catalyst expression in a Column, as expected by `relation.update`.
        val setColumns = setExpressions.mapValues(expr => new Column(expr))
        val updateFilterColumn = Column(_))
        relation.update(updateFilterColumn, setColumns)
      // A relation backed by a catalog table is reported by its qualified name; anything else
      // falls back to the plan's string form.
      case LogicalRelation(_, _, Some(catalogTable), _) =>
        throw HiveAcidErrors.tableNotAcidException(catalogTable.qualifiedName)
      case _ => throw HiveAcidErrors.tableNotAcidException(table.toString())
Example 15
Source File: HiveAcidRelation.scala    From spark-acid   with Apache License 2.0 (5 votes)
package com.qubole.spark.hiveacid.datasource

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
import org.apache.spark.sql.types._
import com.qubole.spark.hiveacid.{HiveAcidErrors, HiveAcidTable, SparkAcidConf}
import com.qubole.spark.hiveacid.hive.HiveAcidMetadata
import com.qubole.spark.hiveacid.merge.{MergeWhenClause, MergeWhenNotInsert}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import collection.JavaConversions._

/**
 * Data source relation for a Hive ACID table; scans, updates and merges are delegated to an
 * internal `HiveAcidTable`.
 *
 * NOTE(review): this snippet is truncated -- closing braces, the arguments of
 * `HiveAcidMetadata.fromSparkSession`, and the bodies of `schema`, `insert`, `delete` and
 * `getHiveAcidTable` are missing. Only visible behaviour is documented.
 *
 * @param sparkSession            session used for configuration and table access
 * @param fullyQualifiedTableName name of the table this relation represents
 * @param parameters              options passed to `HiveAcidTable` and `SparkAcidConf`
 */
case class HiveAcidRelation(sparkSession: SparkSession,
                            fullyQualifiedTableName: String,
                            parameters: Map[String, String])
    extends BaseRelation
    with InsertableRelation
    with PrunedFilteredScan
    with Logging {

  private val hiveAcidMetadata: HiveAcidMetadata = HiveAcidMetadata.fromSparkSession(
  private val hiveAcidTable: HiveAcidTable = new HiveAcidTable(sparkSession,
    hiveAcidMetadata, parameters)

  private val readOptions = SparkAcidConf(sparkSession, parameters)

  override def sqlContext: SQLContext = sparkSession.sqlContext

  // The schema depends on whether row ids were requested (both branches are missing here).
  override val schema: StructType = if (readOptions.includeRowIds) {
  } else {

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
   // sql insert into and overwrite
    if (overwrite) {
    } else {

  /** Delegates to `HiveAcidTable.update` with the optional filter and the new column values. */
  def update(condition: Option[Column], newValues: Map[String, Column]): Unit = {
    hiveAcidTable.update(condition, newValues)

  def delete(condition: Column): Unit = {
  // Size estimate: the session's default size in bytes scaled by the file compression factor.
  override def sizeInBytes: Long = {
    val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor
    (sparkSession.sessionState.conf.defaultSizeInBytes * compressionFactor).toLong

  /** Delegates to `HiveAcidTable.merge`, passing every argument through unchanged. */
  def merge(sourceDf: DataFrame,
            mergeExpression: Expression,
            matchedClause: Seq[MergeWhenClause],
            notMatched: Option[MergeWhenNotInsert],
            sourceAlias: Option[AliasIdentifier],
            targetAlias: Option[AliasIdentifier]): Unit = {
    hiveAcidTable.merge(sourceDf, mergeExpression, matchedClause,
      notMatched, sourceAlias, targetAlias)

  def getHiveAcidTable(): HiveAcidTable = {

  // FIXME: should it be true / false. Recommendation seems to
  //  be to leave it as true
  override val needConversion: Boolean = false

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // NOTE(review): this local shadows the `readOptions` field and is built from the same
    // arguments.
    val readOptions = SparkAcidConf(sparkSession, parameters)
    // sql "select *"
    hiveAcidTable.getRdd(requiredColumns, filters, readOptions)
Example 16
Source File: HiveAcidAutoConvert.scala    From spark-acid   with Apache License 2.0 (5 votes)
package com.qubole.spark.hiveacid

import java.util.Locale

import com.qubole.spark.datasources.hiveacid.sql.execution.SparkAcidSqlParser
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import com.qubole.spark.hiveacid.datasource.HiveAcidDataSource

/**
 * Rule that rewrites convertible `HiveTableRelation`s so that they go through the Hive ACID
 * data source, both as insert targets and (per the "Read path" case) when read.
 *
 * NOTE(review): this snippet is garbled -- `isConvertible` and the first terms of `options`
 * have lost their left-hand operands, and the result of the read-path case is missing.
 * Presumably `isConvertible` inspects the table's serde and its "transactional" property --
 * confirm against the original source.
 *
 * @param spark session whose `sqlContext` is used to create the replacement relation
 */
case class HiveAcidAutoConvert(spark: SparkSession) extends Rule[LogicalPlan] {

  private def isConvertible(relation: HiveTableRelation): Boolean = {
    val serde ="").toLowerCase(Locale.ROOT)"transactional", "false").toBoolean

  // Builds the replacement relation; the options include "table" -> the qualified table name.
  private def convert(relation: HiveTableRelation): LogicalRelation = {
    val options = ++ ++ Map("table" -> relation.tableMeta.qualifiedName)

    val newRelation = new HiveAcidDataSource().createRelation(spark.sqlContext, options)
    LogicalRelation(newRelation, isStreaming = false)

  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan resolveOperators {
      // Write path
      // Only rewritten once the inserted query is resolved and the target is a convertible
      // Hive table.
      case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists)
        if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && isConvertible(r) =>
        InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists)

      // Read path
      case relation: HiveTableRelation
        if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) =>

/**
 * Spark session extension entry point (a `SparkSessionExtensions => Unit` function).
 *
 * NOTE(review): the snippet is truncated right after the start of the `injectParser` call;
 * which parser is injected, and any further injections, are not visible.
 */
class HiveAcidAutoConvertExtension extends (SparkSessionExtensions => Unit) {
  def apply(extension: SparkSessionExtensions): Unit = {
    extension.injectParser { (session, parser) =>
Example 17
Source File: SpatialJoin.scala    From Simba   with Apache License 2.0 (5 votes)
package org.apache.spark.sql.simba.plans

import org.apache.spark.sql.simba.expression.{InCircleRange, InKNN}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, LogicalPlan}
import org.apache.spark.sql.types.BooleanType

/**
 * Logical binary node for a spatial join between two plans.
 *
 * NOTE(review): closing braces are missing and the `DistanceJoin` output branch is truncated
 * (`left.output ++` has lost its right-hand operand).
 *
 * @param left      left input plan
 * @param right     right input plan
 * @param joinType  kind of spatial join
 * @param condition optional join condition; must be boolean-typed for the node to resolve
 */
case class SpatialJoin(left: LogicalPlan, right: LogicalPlan, joinType: SpatialJoinType,
                       condition: Option[Expression]) extends BinaryNode {
  // Every complete branch yields the left attributes followed by the right attributes.
  override def output: Seq[Attribute] = {
    joinType match {
      case KNNJoin =>
        left.output ++ right.output
      case ZKNNJoin =>
        left.output ++ right.output
      case DistanceJoin =>
        left.output ++
      case _ =>
        left.output ++ right.output

  // True when the two sides share no output attribute.
  def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

  // Joins are only resolved if they don't introduce ambiguous expression ids.
  override lazy val resolved: Boolean = {
    childrenResolved &&
      expressions.forall(_.resolved) &&
      selfJoinResolved &&
      condition.forall(_.dataType == BooleanType)
Example 18
Source File: SimbaOptimizer.scala    From Simba   with Apache License 2.0 (5 votes)
package org.apache.spark.sql.simba

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkOptimizer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.simba.plans.SpatialJoin

/**
 * `SparkOptimizer` subclass that appends one extra optimization batch.
 *
 * @param catalog             forwarded to `SparkOptimizer`
 * @param conf                forwarded to `SparkOptimizer`
 * @param experimentalMethods forwarded to `SparkOptimizer`
 */
class SimbaOptimizer(catalog: SessionCatalog,
                     conf: SQLConf,
                     experimentalMethods: ExperimentalMethods)
 extends SparkOptimizer(catalog, conf, experimentalMethods) {
  // The inherited batches, followed by a "SpatialJoinPushDown" batch that runs
  // PushPredicateThroughSpatialJoin with a FixedPoint(100) strategy.
  override def batches: Seq[Batch] = super.batches :+
    Batch("SpatialJoinPushDown", FixedPoint(100), PushPredicateThroughSpatialJoin)

/**
 * Optimizer rule that moves predicates around a `SpatialJoin`: predicates that only
 * reference one side are wrapped in a `Filter` on that side, the rest become the
 * join condition.
 */
object PushPredicateThroughSpatialJoin extends Rule[LogicalPlan] with PredicateHelper {
  // Splits `condition` into three groups, in order:
  //   1. predicates whose references are a subset of left.outputSet,
  //   2. of the remainder, predicates whose references are a subset of right.outputSet,
  //   3. everything else.
  private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
    val (leftEvaluateCondition, rest) =
      condition.partition(_.references subsetOf left.outputSet)
    val (rightEvaluateCondition, commonCondition) =
      rest.partition(_.references subsetOf right.outputSet)

    (leftEvaluateCondition, rightEvaluateCondition, commonCondition)

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // push the where condition down into join filter
    case f @ Filter(filterCondition, SpatialJoin(left, right, joinType, joinCondition)) =>
      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
        split(splitConjunctivePredicates(filterCondition), left, right)

      // A side with no applicable predicate is left unwrapped (getOrElse).
      val newLeft = leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
      val newRight = rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
      // Remaining filter predicates are AND-ed with the existing join condition, if any.
      val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
      SpatialJoin(newLeft, newRight, joinType, newJoinCond)

    // push down the join filter into sub query scanning if applicable
    case f @ SpatialJoin(left, right, joinType, joinCondition) =>
      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
        // NOTE(review): the first argument of `split` is missing from this excerpt.
        split(, left, right)

      val newLeft = leftJoinConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
      val newRight = rightJoinConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
      val newJoinCond = commonJoinCondition.reduceLeftOption(And)

      SpatialJoin(newLeft, newRight, joinType, newJoinCond)
Example 19
Source File: QueryExecution.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.execution

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, QueryExecution => SQLQueryExecution}
import org.apache.spark.sql.simba.SimbaSession

/**
 * Query execution bound to a `SimbaSession`, extending Spark's `QueryExecution`
 * (imported here as `SQLQueryExecution`).
 *
 * NOTE(review): the bodies of all three lazy vals are missing from this excerpt.
 *
 * @param simbaSession session passed to the parent `SQLQueryExecution`
 * @param logical      logical plan to execute
 */
class QueryExecution(val simbaSession: SimbaSession, override val logical: LogicalPlan)
  extends SQLQueryExecution(simbaSession, logical) {

  // Additional lazily computed plan; its computation is not visible here.
  lazy val withIndexedData: LogicalPlan = {

  // Overrides the parent's optimized plan.
  override lazy val optimizedPlan: LogicalPlan = {

  // Overrides the parent's physical plan.
  override lazy val sparkPlan: SparkPlan ={

Example 20
Source File: IndexedRelation.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.index

import org.apache.spark.sql.simba.SimbaSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

/** Pairs an array of rows with an `Index`; presumably the index is built over those rows — TODO confirm. */
private[simba] case class IPartition(data: Array[InternalRow], index: Index)

/** Factory that picks the concrete indexed relation for a given `IndexType`. */
private[simba] object IndexedRelation {
  /**
   * Builds the indexed relation matching `index_type`.
   *
   * @param child       physical plan; its output is used as the relation's output
   * @param table_name  optional table name passed through to the relation
   * @param index_type  selects which relation class is created
   * @param column_keys attributes passed through as the index keys
   * @param index_name  name passed through to the relation
   * @return the created relation, or `null` for an index type not listed below
   */
  def apply(child: SparkPlan, table_name: Option[String], index_type: IndexType,
            column_keys: List[Attribute], index_name: String): IndexedRelation = {
    index_type match {
      case TreeMapType =>
        TreeMapIndexedRelation(child.output, child, table_name, column_keys, index_name)()
      case TreapType =>
        TreapIndexedRelation(child.output, child, table_name, column_keys, index_name)()
      case RTreeType =>
        RTreeIndexedRelation(child.output, child, table_name, column_keys, index_name)()
      case HashMapType =>
        HashMapIndexedRelation(child.output, child, table_name, column_keys, index_name)()
      // NOTE(review): unknown index types yield null; callers must null-check.
      case _ => null

/**
 * Base class for logical plan nodes backed by an `IndexedRDD`. It has no children,
 * and the self-type requires concrete subclasses to be `Product`s.
 */
private[simba] abstract class IndexedRelation extends LogicalPlan {
  self: Product =>
  // Mutable backing field to be provided by subclasses; exposed through `indexedRDD`.
  var _indexedRDD: IndexedRDD
  def indexedRDD: IndexedRDD = _indexedRDD

  // Active session, or null when there is none (`orNull`).
  def simbaSession = SimbaSession.getActiveSession.orNull

  override def children: Seq[LogicalPlan] = Nil
  def output: Seq[Attribute]

  // Returns a relation that uses `newOutput` as its output attributes;
  // presumably a copy of this one — TODO confirm in the subclasses.
  def withOutput(newOutput: Seq[Attribute]): IndexedRelation
Example 21
Source File: ShapeUtils.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.util

import org.apache.spark.sql.simba.{ShapeSerializer, ShapeType}
import org.apache.spark.sql.simba.expression.PointWrapper
import org.apache.spark.sql.simba.spatial.{Point, Shape}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, UnsafeArrayData}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

/**
 * Helpers that evaluate bound expressions against an `InternalRow` to obtain
 * `Point` / `Shape` values.
 *
 * NOTE(review): several expressions in this object are truncated in this excerpt
 * (unclosed calls, missing arguments, empty branches).
 */
object ShapeUtils {
  // Variant taking a physical plan. When `isPoint` is true, the first column is bound
  // against plan.output and passed to ShapeSerializer.deserialize.
  def getPointFromRow(row: InternalRow, columns: List[Attribute], plan: SparkPlan,
                      isPoint: Boolean): Point = {
    if (isPoint) {
      ShapeSerializer.deserialize(BindReferences.bindReference(columns.head, plan.output)
    } else {
      // NOTE(review): the argument to Point(...) is missing from this excerpt.
      Point(, plan.output).eval(row)
  // Same as above, for a logical plan.
  def getPointFromRow(row: InternalRow, columns: List[Attribute], plan: LogicalPlan,
                      isPoint: Boolean): Point = {
    if (isPoint) {
      ShapeSerializer.deserialize(BindReferences.bindReference(columns.head, plan.output)
    } else {
      Point(, plan.output).eval(row)

  // Throws UnsupportedOperationException unless `expression` is a PointWrapper or has
  // a ShapeType data type. NOTE(review): both success branches are empty in this excerpt.
  def getShape(expression: Expression, input: InternalRow): Shape = {
    if (!expression.isInstanceOf[PointWrapper] && expression.dataType.isInstanceOf[ShapeType]) {
    } else if (expression.isInstanceOf[PointWrapper]) {
    } else throw new UnsupportedOperationException("Query shape should be of ShapeType")

  // Variant that first binds `expression` against `schema`.
  def getShape(expression: Expression, schema: Seq[Attribute], input: InternalRow): Shape = {
    if (!expression.isInstanceOf[PointWrapper] && expression.dataType.isInstanceOf[ShapeType]) {
      ShapeSerializer.deserialize(BindReferences.bindReference(expression, schema)
    } else if (expression.isInstanceOf[PointWrapper]) {
      // A PointWrapper is evaluated directly and the result cast to Shape.
      BindReferences.bindReference(expression, schema).eval(input).asInstanceOf[Shape]
    } else throw new UnsupportedOperationException("Query shape should be of ShapeType")

Example 22
Source File: ParamBinder.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server.service

import java.sql.SQLException

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder

/** Binds `ParameterPlaceHolder` expressions in a logical plan to supplied literals. */
object ParamBinder {

  /**
   * Rewrites placeholders whose id is present in `params`, then collects the ids of any
   * placeholders still left in the plan and throws if there are any.
   *
   * NOTE(review): the replacement expression and part of the error-message construction
   * are missing from this excerpt.
   *
   * @param logicalPlan plan whose expressions are transformed
   * @param params      placeholder id -> literal
   * @throws SQLException when unresolved placeholders remain
   */
  def bind(logicalPlan: LogicalPlan, params: Map[Int, Literal]): LogicalPlan = {
    val boundPlan = logicalPlan.transformAllExpressions {
      case ParameterPlaceHolder(id) if params.contains(id) =>
    // Collect the ids of every placeholder remaining in any expression of any plan node.
    val unresolvedParams = boundPlan.flatMap { plan =>
      plan.expressions.flatMap { _.flatMap {
        case ParameterPlaceHolder(id) => Some(id)
        case _ => None
    if (unresolvedParams.nonEmpty) {
      throw new SQLException("Unresolved parameters found: " + => s"$$$n").mkString(", "))
Example 23
Source File: OperationManager.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server.service

import javax.annotation.concurrent.ThreadSafe

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.server.SQLServerConf._
import org.apache.spark.sql.server.SQLServerEnv
import org.apache.spark.sql.types.StructType

/** Closed set of states an `Operation` can be in. */
sealed trait OperationState
// Initial value of `Operation.state`.
case object INITIALIZED extends OperationState
case object RUNNING extends OperationState
// FINISHED, CANCELED, CLOSED and ERROR are the states `Operation.isTimeOut` checks
// when the timeout is positive.
case object FINISHED extends OperationState
case object CANCELED extends OperationState
case object CLOSED extends OperationState
case object ERROR extends OperationState
case object UNKNOWN extends OperationState
case object PENDING extends OperationState

/** Closed set of operation kinds; each value prints as its object name. */
sealed trait OperationType {
  // Strips the trailing "$" from the runtime class's simple name.
  override def toString: String = getClass.getSimpleName.stripSuffix("$")

object BEGIN extends OperationType
object FETCH extends OperationType
object SELECT extends OperationType

/**
 * A server-side operation with a lifecycle (`prepare`, `run`, `cancel`, `close`),
 * a tracked state and an idle timeout.
 */
trait Operation {

  // Idle timeout read from the server configuration; its sign selects the branch
  // taken in isTimeOut.
  private val timeout = SQLServerEnv.sqlConf.sqlServerIdleOperationTimeout
  // Wall-clock millis of creation or of the most recent state change.
  private var lastAccessTime: Long = System.currentTimeMillis()

  @volatile protected var state: OperationState = INITIALIZED

  def statementId(): String
  def outputSchema(): StructType
  def prepare(params: Map[Int, Literal]): Unit
  def run(): Iterator[InternalRow]
  def cancel(): Unit
  def close(): Unit

  // Records the new state and refreshes the last-access timestamp.
  protected def setState(newState: OperationState): Unit = {
    lastAccessTime = System.currentTimeMillis()
    state = newState

  /**
   * Whether this operation has timed out as of `current` (millis).
   *
   *  - timeout == 0: NOTE(review) the result for this branch is missing from this excerpt.
   *  - timeout > 0: timed out only when the state is FINISHED, CANCELED, CLOSED or ERROR
   *    and at least `timeout` millis have passed since the last access.
   *  - timeout < 0: timed out once `-timeout` millis have passed since the last access,
   *    regardless of state.
   */
  def isTimeOut(current: Long): Boolean = {
    if (timeout == 0) {
    } else if (timeout > 0) {
      Seq(FINISHED, CANCELED, CLOSED, ERROR).contains(state) &&
        lastAccessTime + timeout <= current
    } else {
      lastAccessTime - timeout <= current

/**
 * No-op `Operation`: statement id "nop", empty output schema, no rows, and
 * lifecycle methods that do nothing.
 */
object NOP extends Operation {

  override val statementId: String = "nop"
  override val outputSchema: StructType = new StructType()

  override def prepare(params: Map[Int, Literal]): Unit = {}
  override def run(): Iterator[InternalRow] = Iterator.empty
  override def cancel(): Unit = {}
  override def close(): Unit = {}

/** Factory interface for creating `Operation` instances. */
trait OperationExecutor {

  // Creates a new instance for service-specific operations
  // `query` pairs a String with a LogicalPlan; presumably the SQL text and its
  // parsed plan — TODO confirm against callers.
  def newOperation(
    sessionState: SessionState,
    statementId: String,
    query: (String, LogicalPlan)): Operation
Example 24
Source File: operators.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package edu.ucla.cs.wis.bigdatalog.spark.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, LeafNode, LogicalPlan, Statistics, UnaryNode}

/**
 * Binary logical node for a recursion; its output is that of the recursive-rules
 * plan (`right`).
 *
 * @param name         name of the recursion
 * @param isLinear     linearity flag; not used within this visible block
 * @param left         exit-rules plan
 * @param right        recursive-rules plan
 * @param partitioning integer sequence; presumably partitioning column positions — TODO confirm
 */
case class Recursion(name: String,
                     isLinear: Boolean,
                     left: LogicalPlan,
                     right: LogicalPlan,
                     partitioning: Seq[Int]) extends BinaryNode {
  // left is exitRules plan
  // right is recursive rules plan
  override def output: Seq[Attribute] = right.output

/**
 * Binary logical node for a mutual recursion; its output is that of `right`.
 * `left` may be null (see the check in `children`).
 *
 * NOTE(review): several lines of this class are missing from this excerpt.
 */
case class MutualRecursion(name: String,
                             isLinear: Boolean,
                             left: LogicalPlan,
                             right: LogicalPlan,
                             partitioning: Seq[Int]) extends BinaryNode {

  override def output: Seq[Attribute] = right.output

  // NOTE(review): as shown, a null `left` is put into the child list; the original
  // null-case result and the `else` keyword appear to have been dropped — confirm
  // against the upstream source.
  override def children: Seq[LogicalPlan] = {
    if (left == null)
      Seq(left, right)

  // Custom tree rendering. NOTE(review): the statements that append to `builder`
  // are missing from this excerpt.
  override def generateTreeString(depth: Int,
                         lastChildren: Seq[Boolean],
                         builder: StringBuilder): StringBuilder = {
    if (depth > 0) {
      lastChildren.init.foreach { isLast =>
        val prefixFragment = if (isLast) "   " else ":  "

      val branch = if (lastChildren.last) "+- " else ":- "


    if (children.nonEmpty) {
      // Every child except the last is rendered as a non-last sibling;
      // the last child is rendered as the final sibling.
      val exitRule = children.init
      if (exitRule != null)
        exitRule.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)


/** Leaf node for a linear recursive relation; reports `Long.MaxValue` as its statistics value. */
case class LinearRecursiveRelation(_name: String, output: Seq[Attribute], partitioning: Seq[Int]) extends LeafNode {
  override def statistics: Statistics = Statistics(Long.MaxValue)
  // Mutable name, initialised from the constructor argument.
  var name = _name

/** Leaf node for a non-linear recursive relation; reports `Long.MaxValue` as its statistics value. */
case class NonLinearRecursiveRelation(_name: String, output: Seq[Attribute], partitioning: Seq[Int]) extends LeafNode {
  override def statistics: Statistics = Statistics(Long.MaxValue)

  // The exposed name is the constructor name prefixed with "all_".
  def name = "all_" + _name

/**
 * Unary logical node carrying grouping and aggregate expressions over `child`.
 *
 * NOTE(review): the body of `output` is missing from this excerpt.
 */
case class MonotonicAggregate(groupingExpressions: Seq[Expression],
                              aggregateExpressions: Seq[NamedExpression],
                              child: LogicalPlan,
                              partitioning: Seq[Int]) extends UnaryNode {
  // Resolved when no expression is unresolved and all children are resolved.
  override lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved

  override def output: Seq[Attribute] =

/**
 * Binary logical node for a recursion with aggregation; its output is that of the
 * recursive-rules plan (`right`).
 */
case class AggregateRecursion(name: String,
                              isLinear: Boolean,
                              left: LogicalPlan,
                              right: LogicalPlan,
                              partitioning: Seq[Int]) extends BinaryNode {
  // left is exitRules plan
  // right is recursive rules plan
  override def output: Seq[Attribute] = right.output

/** Leaf node for an aggregate relation; reports `Long.MaxValue` as its statistics value. */
case class AggregateRelation(_name: String, output: Seq[Attribute], partitioning: Seq[Int]) extends LeafNode {
  override def statistics: Statistics = Statistics(Long.MaxValue)
  // Mutable name, initialised from the constructor argument.
  var name = _name

/** Unary marker node whose output is the child's output; presumably a caching hint for the planner — TODO confirm. */
case class CacheHint(child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
Example 25
Source File: ExtraStrategiesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Leaf physical operator used by the test strategy; `doExecute` builds a row from
 * the literal "so fast".
 *
 * NOTE(review): the statement that returns the RDD is missing from this excerpt.
 */
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    // Project the generic row to an unsafe row and take a copy of the result.
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()

  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil

/**
 * Test strategy that plans certain single-expression `Project` nodes as a
 * `FastOperator` and declines every other plan by returning Nil.
 */
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // NOTE(review): the left operand of the guard is missing from this excerpt.
    case Project(Seq(attr), _) if == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil

/**
 * Test suite that installs `TestStrategy` through `spark.experimental.extraStrategies`.
 *
 * NOTE(review): the `checkAnswer` calls are partially missing from this excerpt.
 */
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
        Row("so fast"))

      checkAnswer("a", "b"),
        Row("so slow", 1))
    } finally {
      // Always remove the extra strategy again, whether or not the assertions passed.
      spark.experimental.extraStrategies = Nil
Example 26
Source File: SparkSchemaProvider.scala    From mimir   with Apache License 2.0 5 votes vote down vote up

import com.typesafe.scalalogging.LazyLogging
import org.apache.spark.sql.{ DataFrame, SaveMode } 
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.analysis.{ UnresolvedRelation, NoSuchDatabaseException }
import org.apache.spark.sql.execution.command.{ DropTableCommand, CreateDatabaseCommand }

import mimir.Database
import mimir.algebra._
import mimir.exec.spark.{MimirSpark, RAToSpark, RowIndexPlan}

/**
 * Schema provider constructed from a mimir `Database`, mixing in
 * `LogicalPlanSchemaProvider`, `MaterializedTableProvider` and `LazyLogging`.
 *
 * NOTE(review): most method bodies, and apparently at least one method header,
 * are missing from this excerpt.
 */
class SparkSchemaProvider(db: Database)
  extends LogicalPlanSchemaProvider
  with MaterializedTableProvider
  with LazyLogging

  // NOTE(review): the lines below seem to mix fragments of more than one method
  // (they mention `col` and `$table`) — confirm against the upstream source.
  def listTables(): Seq[ID] = 
    try {
      val tables = 
                .map { col => (
                  ) }
      } else { 
        logger.trace(s"$table doesn't exist")
    } catch {
      case _:NoSuchDatabaseException => {
        // NOTE(review): this literal has no `s` interpolator, so "$sparkDBName" is
        // logged verbatim rather than substituted — confirm whether interpolation
        // was intended.
        logger.warn("Couldn't find database!!! ($sparkDBName)")

  // Returns the logical plan for `table`; body not visible here.
  def logicalplan(table: ID): LogicalPlan =

  // Stores `data` under `name`; body not visible here.
  def createStoredTableAs(data: DataFrame, name: ID)

  // Drops the stored table `name`; only part of the call's argument list is visible.
  def dropStoredTable(name: ID)
      TableIdentifier(, None),//Option(sparkDBName)), 
      true, false, true
Example 27
Source File: SampleRowsUniformly.scala    From mimir   with Apache License 2.0 5 votes vote down vote up
package mimir.algebra.sampling

import org.apache.spark.sql.functions.{ rand, udf }
import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, Filter }
import play.api.libs.json._
import mimir.algebra._

/**
 * Sampling mode parameterised by a single `probability`.
 *
 * NOTE(review): the expression that builds the resulting plan in `apply` is missing
 * from this excerpt.
 */
case class SampleRowsUniformly(probability:Double) extends SamplingMode
  override def toString = s"WITH PROBABILITY $probability"

  def apply(plan: LogicalPlan, seed: Long): LogicalPlan = 
    // Adapted from Spark's df.stat.sampleBy method
    // `r` is a seeded random column; `f` tests whether a value is below `probability`.
    val r = rand(seed)
    val f = udf { (x: Double) => x < probability }
  // This mode carries no expressions, so rebuilding returns the same instance.
  def expressions: Seq[Expression] = Seq()
  def rebuildExpressions(x: Seq[Expression]): SamplingMode = this

  // JSON form: {"mode": SampleRowsUniformly.MODE, "probability": <probability>}.
  def toJson: JsValue = JsObject(Map[String,JsValue](
    "mode" -> JsString(SampleRowsUniformly.MODE),
    "probability" -> JsNumber(probability)

/** Companion holding the JSON mode tag and the JSON parser for `SampleRowsUniformly`. */
object SampleRowsUniformly
  // Value written under the "mode" key by `toJson`.
  val MODE = "uniform_probability"

  // NOTE(review): the condition and both branch bodies are missing from this excerpt.
  def parseJson(json:Map[String, JsValue]): Option[SampleRowsUniformly] =
    } else {
Example 28
Source File: SamplingMode.scala    From mimir   with Apache License 2.0 5 votes vote down vote up
package mimir.algebra.sampling

import play.api.libs.json._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import mimir.algebra.{RAException, Expression}

/** Contract for sampling strategies applied to a Spark logical plan. */
trait SamplingMode
  // Produces a plan from `plan` using `seed`; presumably the sampled plan — TODO confirm.
  def apply(plan: LogicalPlan, seed: Long): LogicalPlan
  // Expressions this mode depends on.
  def expressions: Seq[Expression]
  // Returns a mode rebuilt from replacement expressions.
  def rebuildExpressions(x: Seq[Expression]): SamplingMode
  // JSON form of this mode.
  def toJson: JsValue

/** JSON (de)serialization entry points for `SamplingMode`. */
object SamplingMode
  // Candidate parsers, tried in order by fromJson.
  // NOTE(review): the list elements are missing from this excerpt.
  val parsers = Seq[Map[String,JsValue] => Option[SamplingMode]](

  /**
   * Returns the result of the first parser that accepts the configuration.
   *
   * @throws RAException when no parser accepts `v`
   */
  def fromJson(v: JsValue): SamplingMode =
    val config =[Map[String,JsValue]]
    for(parser <- parsers){
      parser(config) match {
        // NOTE(review): `return` inside the loop body exits fromJson early; a
        // find/collectFirst form would avoid the non-local return.
        case Some(s) => return s
        case None => ()
    throw new RAException(s"Invalid Sampling Mode: $v")

  // play-json Format: reads via fromJson (always wrapped in JsSuccess), writes via toJson.
  implicit val format = Format[SamplingMode](
    new Reads[SamplingMode]{ def reads(v: JsValue) = JsSuccess(fromJson(v)) },
    new Writes[SamplingMode]{ def writes(mode: SamplingMode) = mode.toJson }
Example 29
Source File: SampleStratifiedOn.scala    From mimir   with Apache License 2.0 5 votes vote down vote up
package mimir.algebra.sampling

import org.apache.spark.sql.functions.{rand, udf, col}
import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, Filter }
import play.api.libs.json._
import mimir.algebra._
import mimir.exec.spark.RAToSpark
import mimir.serialization.{ Json => MimirJson }

/**
 * Sampling mode keyed on a column: `strata` maps a column value to a probability.
 *
 * NOTE(review): several expressions in this class are truncated in this excerpt.
 *
 * @param column stratification column
 * @param t      type used when converting strata keys and when serializing
 * @param strata value -> probability
 */
case class SampleStratifiedOn(column:ID, t:Type, strata:Map[PrimitiveValue,Double]) extends SamplingMode
  // Strata with each key converted through RAToSpark.getNative(v, t).
  val sparkStrata = { case (v, p) => RAToSpark.getNative(v, t) -> p }

  override def toString = s"ON $column WITH STRATA ${ { case (v,p) => s"$v -> $p"}.mkString(" | ")}"

  def apply(plan: LogicalPlan, seed: Long): LogicalPlan =
    // Adapted from Spark's df.stat.sampleBy method
    val c = col(
    val r = rand(seed)
    // Tests x against the stratum's probability; unknown strata use 0.0.
    val f = udf { (stratum: Any, x: Double) =>
              x < sparkStrata.getOrElse(stratum, 0.0)
      f(c, r).expr,
  def expressions: Seq[Expression] = Seq(Var(column))
  // Only a plain variable may replace the stratification column.
  def rebuildExpressions(x: Seq[Expression]): SamplingMode =
    x(0) match { 
      case Var(newColumn) => SampleStratifiedOn(newColumn, t, strata) 
      case _ => throw new RAException("Internal Error: Rewriting stratification variable with arbitrary expression")

  // JSON form with keys "mode", "column", "type" and "strata"
  // (each stratum as {"value", "probability"}).
  def toJson: JsValue = JsObject(Map[String,JsValue](
    "mode" -> JsString(SampleStratifiedOn.MODE),
    "column" -> JsString(,
    "type" -> MimirJson.ofType(t),
    "strata" -> JsArray(
        .map { case (v, p) => JsObject(Map[String,JsValue](
            "value" -> MimirJson.ofPrimitive(v),
            "probability" -> JsNumber(p)

/** Companion holding the JSON mode tag and the JSON parser for `SampleStratifiedOn`. */
object SampleStratifiedOn
  // Value written under the "mode" key by `toJson`.
  val MODE = "stratified_on"

  // NOTE(review): the condition and most of the body are missing from this excerpt.
  def parseJson(json:Map[String, JsValue]): Option[SampleStratifiedOn] =
      // The type is read from the "type" key and used to decode each stratum value.
      val t = MimirJson.toType(json("type"))

          .map { stratum => 
            MimirJson.toPrimitive(t, stratum("value")) -> 
    } else {
Example 30
Source File: StarryAggStrategy.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.stanlee.execution.aggregate.StarryAggUtils

/**
 * Planning strategy for aggregations matched by `PhysicalAggregation`; every other
 * plan is declined by returning Nil.
 *
 * NOTE(review): the partitioning expression, the distinct-count condition and both
 * planner calls are missing from this excerpt.
 */
case class StarryAggStrategy() extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalAggregation(
    groupingExpressions, aggregateExpressions, resultExpressions, child) =>

      val (functionsWithDistinct, functionsWithoutDistinct) =
      if ( > 1) {
        // This is a sanity check. We should not reach here when we have multiple distinct
        // column sets. Our MultipleDistinctRewriter should take care this case.
        sys.error("You hit a query analyzer bug. Please report your query to " +
          "Spark user mailing list.")

      // The operator is chosen by whether any aggregate function uses DISTINCT.
      val aggregateOperator =
        if (functionsWithDistinct.isEmpty) {
        } else {


    case _ => Nil
Example 31
Source File: treenodes.scala    From ingraph   with Eclipse Public License 1.0 5 votes vote down vote up
package ingraph.model.treenodes

import ingraph.compiler.exceptions.UnexpectedTypeException
import ingraph.model.expr.{HasExtraChildren, RichEdgeAttribute}
import ingraph.model.fplan.FNode
import ingraph.model.nplan.NNode
import ingraph.model.gplan.GNode
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.{expressions => cExpr}

/** Logical plan node with no children, typed by the plan family `TPlan`. */
abstract class GenericLeafNode[TPlan <: LogicalPlan] extends LogicalPlan {
  override final def children: Seq[TPlan] = Nil

/** Logical plan node with exactly one child of type `TPlan`. */
abstract class GenericUnaryNode[TPlan <: LogicalPlan] extends LogicalPlan {
  def child: TPlan

  // The child list is the single `child`.
  override final def children: Seq[TPlan] = child :: Nil

/** Logical plan node with exactly two children of type `TPlan`. */
abstract class GenericBinaryNode[TPlan <: LogicalPlan] extends LogicalPlan {
  def left: TPlan
  def right: TPlan

  // Children are ordered left, then right.
  override final def children: Seq[TPlan] = Seq(left, right)

  /**
   * Searches for a node satisfying `p`, starting at `startNode`: the node itself is
   * tested first, then its `ingraphChildren` in order, recursively; the first hit is kept.
   *
   * NOTE(review): the result of the `p(startNode)` branch and the enclosing definition
   * are missing from this excerpt.
   */
  def find(p: IngraphTreeNodeType => Boolean, startNode: IngraphTreeNodeType): Option[IngraphTreeNodeType] = {
    if (p(startNode)) {
    } else {
      // `b.orElse(...)` takes a by-name argument, so once a match is found the
      // remaining children are not searched.
      ingraphChildren(startNode).foldLeft[Option[IngraphTreeNodeType]](Option.empty)( (b, a) => b.orElse(find(p, a)))
Example 32
Source File: GraphTexConverter.scala    From ingraph   with Eclipse Public License 1.0 5 votes vote down vote up
package ingraph.util.tex

import ingraph.model.plan.TProduction
import ingraph.model.treenodes.GenericBinaryNode
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/** `TexConverter` that renders a logical plan as a LaTeX `forest` tree. */
class GraphTexConverter[T <: LogicalPlan] extends TexConverter[T] {
  // Wraps the tree in a forest environment before delegating to the parent.
  override def toStandaloneDoc(tex: String, ingraphDir: String): String = {
    val texGraph = s"\\begin{forest} $tex \\end{forest}"
    super.toStandaloneDoc(texGraph, ingraphDir)

  // Renders `p` as a bracketed node: production operators are skipped, binary nodes
  // emit their own label followed by both subtrees, everything else uses the
  // parent's rendering.
  override def toTex(p: LogicalPlan): String = {
    p match {
      case p: TProduction => toTex(p.child)  // ignore production operator
      case b: GenericBinaryNode[T] => s"[${binaryToTeX(b)} ${toTex(b.left)} ${toTex(b.right)}]\n"
      case _=> s"[${super.toTex(p)}]\n"
Example 33
Source File: GPlanExpander.scala    From ingraph   with Eclipse Public License 1.0 5 votes vote down vote up
package ingraph.compiler.cypher2gplan

import ingraph.model.{expr, gplan}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.{expressions => cExpr}

/**
 * Expands a raw gplan by rewriting `GetVertices` and `Expand` nodes into `Selection`s
 * built from property maps.
 *
 * NOTE(review): guards and several arguments are missing from this excerpt.
 */
object GPlanExpander {
  def expandGPlan(rawQueryPlan: gplan.GNode): gplan.GNode = {
    // should there be other rule sets (partial functions), combine them using orElse,
    // e.g. pfunc1 orElse pfunc2
    // expanding GetVertices involves creating other GetVertices, so transformUp is to avoid infinite recursion
    val full = rawQueryPlan.transformUp(gplanExpander)


  // Rewrite rules applied bottom-up by expandGPlan.
  val gplanExpander: PartialFunction[LogicalPlan, LogicalPlan] = {
    // Nullary
    case gplan.GetVertices(vertexAttribute) if => {
      val condition: Expression = propertyMapToCondition(,
      gplan.Selection(condition, gplan.GetVertices(vertexAttribute))
    case gplan.Expand(srcVertexAttribute, trgVertexAttribute, edge, dir, child) if || => {
      // The edge selection wraps the Expand; the target-vertex selection wraps the edge selection.
      val selectionOnEdge = gplan.Selection(propertyMapToCondition(,, gplan.Expand(srcVertexAttribute, trgVertexAttribute, edge, dir, child))
      val selectionOnTargetVertex = gplan.Selection(propertyMapToCondition(,, selectionOnEdge)

  // Builds a conjunction of `baseName.<key> = <value>` equalities, folded onto a
  // literal `true` with And.
  def propertyMapToCondition(properties: expr.types.TPropertyMap, baseName: String): Expression = { (p) => cExpr.EqualTo(UnresolvedAttribute(Seq(baseName, p._1)), p._2) )
              .foldLeft[Expression]( cExpr.Literal(true) )( (b, a) => cExpr.And(b, a) )
Example 34
Source File: RecursivePlanDetails.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package edu.ucla.cs.wis.bigdatalog.spark

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import scala.collection.mutable.HashMap

/**
 * Serializable holder of three name -> LogicalPlan maps (base, recursive and
 * aggregate relations), all mutable.
 */
class RecursivePlanDetails extends Serializable {
  private val baseRelationsByName = new HashMap[String, LogicalPlan]
  val recursiveRelations = new HashMap[String, LogicalPlan]
  val aggregateRelations = new HashMap[String, LogicalPlan]

  // Registers (or replaces) the base relation stored under `name`.
  def addBaseRelation(name: String, obj: LogicalPlan) = {
    baseRelationsByName.put(name, obj)

  // NOTE(review): body missing from this excerpt.
  def containsBaseRelation(name: String): Boolean = {
Example 35
Source File: SparkPlannerSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Suite checking that only the first candidate produced by a strategy is planned
 * further: each match returns a second candidate containing `NeverPlanned`, whose
 * own case fails the test.
 */
class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Sentinel leaf; matching it in the strategy fails the test.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil

    // Counts how many plan nodes TestStrategy has matched.
    var planned = 0
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          // NOTE(review): the argument to UnionExec is missing from this excerpt.
          UnionExec( :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data, _) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // Presumably ReturnAnswer + Union + two LocalRelations — TODO confirm.
      assert(planned === 4)
    } finally {
      // Always remove the extra strategy again.
      spark.experimental.extraStrategies = Nil
Example 36
Source File: BigDatalogProgram.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package edu.ucla.cs.wis.bigdatalog.spark

import edu.ucla.cs.wis.bigdatalog.interpreter.OperatorProgram
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Wraps a logical plan together with its `BigDatalogContext` and operator program.
 *
 * NOTE(review): the bodies of `count` and `execute` are missing from this excerpt.
 *
 * @param bigDatalogContext context used to build DataFrames and query executions
 * @param plan              logical plan of the program
 * @param operatorProgram   operator program; not used within the visible lines
 */
class BigDatalogProgram(var bigDatalogContext: BigDatalogContext,
                        plan: LogicalPlan,
                        operatorProgram: OperatorProgram) {

  // Creates a new DataFrame over `plan` on every call.
  def toDF(): DataFrame = {
    new DataFrame(bigDatalogContext, plan)
  def count(): Long = {

  // use this method to produce an rdd containing the results for the program (i.e., it evaluates the program)
  def execute(): RDD[Row] = {

  // String form of a fresh QueryExecution for `plan`.
  override def toString(): String = {
    new QueryExecution(bigDatalogContext, plan).toString
Example 37
Source File: HiveStrategies.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand}
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _}
import org.apache.spark.sql.hive.execution._

/**
 * Hive-specific planning strategies, mixed into a `SQLContext#SparkPlanner`.
 *
 * NOTE(review): several constructor calls and expressions are missing from this
 * excerpt (only their argument lists remain).
 */
private[hive] trait HiveStrategies {
  // Possibly being too clever with types here... or not clever enough.
  self: SQLContext#SparkPlanner =>

  val hiveContext: HiveContext

  // Plans logical ScriptTransformation nodes that carry a HiveScriptIOSchema.
  object Scripts extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) =>
        ScriptTransformation(input, script, output, planLater(child), schema)(hiveContext) :: Nil
      case _ => Nil

  // Plans inserts whose target table is a MetastoreRelation.
  object DataSinks extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.InsertIntoTable(
          table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
          table, partition, planLater(child), overwrite, ifNotExists) :: Nil
      case hive.InsertIntoHiveTable(
          table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
          table, partition, planLater(child), overwrite, ifNotExists) :: Nil
      case _ => Nil

  // Plans scans over MetastoreRelations, splitting predicates into partition-pruning
  // predicates and the rest.
  object HiveTableScans extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
        // Filter out all predicates that only deal with partition keys, these are given to the
        // hive table scan operator to be used for partition pruning.
        val partitionKeyIds = AttributeSet(relation.partitionKeys)
        val (pruningPredicates, otherPredicates) = predicates.partition { predicate =>
          !predicate.references.isEmpty &&

          HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
      case _ =>

  // Plans CREATE TABLE ... USING (and ... AS SELECT) as executed commands; both patterns
  // only match when the boolean field shown as `false` is false.
  object HiveDDLStrategy extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case CreateTableUsing(
        tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) =>
        val cmd =
            tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)
        ExecutedCommand(cmd) :: Nil

      case CreateTableUsingAsSelect(
        tableIdent, provider, false, partitionCols, mode, opts, query) =>
        val cmd =
          CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query)
        ExecutedCommand(cmd) :: Nil

      case _ => Nil

  // Plans DESCRIBE: the table is analyzed first; a MetastoreRelation is described with
  // DescribeHiveTableCommand, any other plan via its executed physical plan.
  case class HiveCommandStrategy(context: HiveContext) extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case describe: DescribeCommand =>
        val resolvedTable = context.executePlan(describe.table).analyzed
        resolvedTable match {
          case t: MetastoreRelation =>
              DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil

          case o: LogicalPlan =>
            val resultPlan = context.executePlan(o).executedPlan
              resultPlan, describe.output, describe.isExtended)) :: Nil

      case _ => Nil
Example 38
Source File: CreateTableAsSelect.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable}
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation}
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}

/**
 * Runnable command that creates a Hive table and fills it with the result of `query`.
 *
 * `run` checks `hiveContext.catalog.tableExists(tableIdentifier)`: when the table
 * exists it either does nothing (`allowExisting`) or throws an `AnalysisException`;
 * otherwise it executes an `InsertIntoTable` of `query` into the lazily resolved
 * metastore relation.
 *
 * NOTE(review): this listing is extraction-damaged. Closing braces, the first argument
 * of `TableIdentifier(...)`, the right-hand sides of `withFormat` / `inputFormat` /
 * `outputFormat`, the `schema = ... =>` mapping and the `else` branch are missing.
 *
 * @param tableDesc     description of the table to create
 * @param query         logical plan whose output is inserted into the new table
 * @param allowExisting when true, an already existing table is silently left untouched
 */
case class CreateTableAsSelect(
    tableDesc: HiveTable,
    query: LogicalPlan,
    allowExisting: Boolean)
  extends RunnableCommand {

  // NOTE(review): the table-name argument was lost in extraction.
  val tableIdentifier = TableIdentifier(, Some(tableDesc.database))

  override def children: Seq[LogicalPlan] = Seq(query)

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val hiveContext = sqlContext.asInstanceOf[HiveContext]
    // Lazy: only evaluated on the branch below that actually inserts data.
    lazy val metastoreRelation: MetastoreRelation = {
      import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
      import org.apache.hadoop.mapred.TextInputFormat

      // Falls back to default output format / serde class names when tableDesc has none.
      val withFormat =
          inputFormat =
          outputFormat =
              .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)),
          serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName())))

      val withSchema = if (withFormat.schema.isEmpty) {
        // Hive doesn't support specifying the column list for the target table in CTAS.
        // However, we don't think Spark SQL should follow that.
        tableDesc.copy(schema = =>
          HiveColumn(, HiveMetastoreTypes.toMetastoreType(c.dataType), null)))
      } else {


      // Get the Metastore Relation
      hiveContext.catalog.lookupRelation(tableIdentifier, None) match {
        case r: MetastoreRelation => r
    // TODO: ideally, we should get the output data ready first and only then add the
    // relation to the catalog, in case a failure occurs during data
    // processing.
    if (hiveContext.catalog.tableExists(tableIdentifier)) {
      if (allowExisting) {
        // Table already exists: do nothing, to stay consistent with Hive.
      } else {
        throw new AnalysisException(s"$tableIdentifier already exists.")
    } else {
      hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd


  // NOTE(review): the second `}` after `${tableDesc.database}` is outside the
  // interpolation and is emitted literally -- looks like a typo. The `${}` expression
  // for the table name was lost in extraction.
  override def argString: String = {
    s"[Database:${tableDesc.database}}, TableName: ${}, InsertIntoHiveTable]"
Example 39
Source File: SameResultSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._

/**
 * Tests for `LogicalPlan.sameResult`: two plans built over separate `LocalRelation`
 * instances are analyzed and compared.
 *
 * NOTE(review): attribute definitions inside `LocalRelation(...)`, the receivers of
 * the projection calls and all closing braces were lost in extraction.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation(', ', '
  val testRelation2 = LocalRelation(', ', '

  /**
   * Analyzes both plans and fails the test when `sameResult` differs from `result`.
   * The failure message contains a side-by-side rendering of the two analyzed plans.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val aAnalyzed = a.analyze
    val bAnalyzed = b.analyze

    if (aAnalyzed.sameResult(bAnalyzed) != result) {
      val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")

  test("relations") {
    assertSameResult(testRelation, testRelation2)

  test("projections") {
    // Expected to be the same result.
    assertSameResult('a, 'b),'a, 'b))
    assertSameResult('b, 'a),'b, 'a))

    // Expected to differ (result = false).
    assertSameResult(testRelation,'a), result = false)
    assertSameResult('b, 'a),'a, 'b), result = false)

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
Example 40
Source File: ConvertToLocalRelationSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Optimizer test for the `ConvertToLocalRelation` rule: the optimized plan is
 * compared with a hand-built `LocalRelation` via `comparePlans`.
 *
 * NOTE(review): the attribute lists of the relations, the start of the
 * `projectOnLocal` expression and all closing braces were lost in extraction.
 */
class ConvertToLocalRelationSuite extends PlanTest {

  // Runs only ConvertToLocalRelation, repeated up to a fixed point of 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // Input rows: (1, 2) and (4, 5).
    val testRelation = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    // Expected rows: (1, 3) and (4, 6).
    val correctAnswer = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    // Projects `b + 1` under the alias "b1".
    val projectOnLocal =
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)

Example 41
Source File: ColumnPruningSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType

/**
 * Optimizer tests for the `ColumnPruning` rule, covering `Generate` (with
 * `join = true` / `join = false`) and a `Project` over a `Sort`.
 *
 * NOTE(review): several expression prefixes (relation attributes, the receivers of
 * the `.generate` / `.select` chains) and all closing braces were lost in extraction.
 */
class ColumnPruningSuite extends PlanTest {

  // Runs only ColumnPruning, repeated up to a fixed point of 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Column pruning", FixedPoint(100),
      ColumnPruning) :: Nil

  test("Column pruning for Generate when Generate.join = false") {
    val input = LocalRelation(', 'b.array(StringType))

    val query = input.generate(Explode('b), join = false).analyze

    val optimized = Optimize.execute(query)

    val correctAnswer ='b).generate(Explode('b), join = false).analyze

    comparePlans(optimized, correctAnswer)

  test("Column pruning for Generate when Generate.join = true") {
    val input = LocalRelation(', ', 'c.array(StringType))

    val query =
        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
        .select('a, 'explode)

    val optimized = Optimize.execute(query)

    // The expected plan has an extra select of 'a and 'c before the generate.
    val correctAnswer =
        .select('a, 'c)
        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
        .select('a, 'explode)

    comparePlans(optimized, correctAnswer)

  test("Turn Generate.join to false if possible") {
    val input = LocalRelation('b.array(StringType))

    val query =
        .generate(Explode('b), join = true, outputNames = "explode" :: Nil)
        .select(('explode + 1).as("result"))

    val optimized = Optimize.execute(query)

    // Same plan as the query except that join is expected to be false.
    val correctAnswer =
        .generate(Explode('b), join = false, outputNames = "explode" :: Nil)
        .select(('explode + 1).as("result"))

    comparePlans(optimized, correctAnswer)

  test("Column pruning for Project on Sort") {
    val input = LocalRelation(', 'b.string, 'c.double)

    val query = input.orderBy('b.asc).select('a).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer ='a, 'b).orderBy('b.asc).select('a).analyze

    comparePlans(optimized, correctAnswer)

  // TODO: add more tests for column pruning
Example 42
Source File: ProjectCollapsingSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Optimizer tests for the `ProjectCollapsing` rule with deterministic and
 * nondeterministic (`Rand`) projections.
 *
 * NOTE(review): the relation's attributes, several `.select(...)` chains and all
 * closing braces were lost in extraction; some tests below show only a fragment of
 * their query or expected plan.
 */
class ProjectCollapsingSuite extends PlanTest {
  // EliminateSubQueries runs to a fixed point (10) first; ProjectCollapsing runs once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
        Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil

  val testRelation = LocalRelation(', '

  test("collapse two deterministic, independent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer ='a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, correctAnswer)

  test("collapse two deterministic, dependent projects into one") {
    // The second projection reads the alias produced by the first one.
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      (('a + 1).as('a_plus_1) + 1).as('a_plus_2),

    comparePlans(optimized, correctAnswer)

  test("do not collapse nondeterministic projects") {
    // The optimized plan is expected to equal the analyzed query unchanged.
    val query = testRelation
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = query.analyze

    comparePlans(optimized, correctAnswer)

  test("collapse two nondeterministic, independent projects into one") {
    // NOTE(review): the projections of both the query and the expected plan are missing.
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation

    comparePlans(optimized, correctAnswer)

  test("collapse one nondeterministic, one deterministic, independent projects into one") {
    val query = testRelation
      .select(Rand(10).as('rand), 'a)
      .select(('a + 1).as('a_plus_1))

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation
      .select(('a + 1).as('a_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
Example 43
Source File: AggregateOptimizeSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Optimizer tests around aggregates; the only rule in the batch is
 * `RemoveLiteralFromGroupExpressions`.
 *
 * NOTE(review): the relation attributes and all closing braces were lost in extraction.
 */
class AggregateOptimizeSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      RemoveLiteralFromGroupExpressions) :: Nil

  test("replace distinct with aggregate") {
    val input = LocalRelation(', '

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    // Expected: an Aggregate that both groups by and outputs every input column.
    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)

  test("remove literals in grouping expression") {
    val input = LocalRelation(', '

    val query =
      input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(query)

    // Expected: only the non-literal grouping key 'a remains.
    val correctAnswer = input.groupBy('a)(sum('b))

    comparePlans(optimized, correctAnswer)
Example 44
Source File: InsertIntoDataSource.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

/**
 * Runnable command that inserts the result of `query` into an `InsertableRelation`.
 *
 * NOTE(review): the statements after the "Invalidate the cache" comment, the return
 * value and the closing braces are missing from this listing.
 *
 * @param logicalRelation target relation; its `relation` is cast to `InsertableRelation`
 * @param query           logical plan producing the rows to insert
 * @param overwrite       passed straight through to `InsertableRelation.insert`
 */
private[sql] case class InsertIntoDataSource(
    logicalRelation: LogicalRelation,
    query: LogicalPlan,
    overwrite: Boolean)
  extends RunnableCommand {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    // Unchecked cast: assumes the target relation supports inserts.
    val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
    val data = DataFrame(sqlContext, query)
    // Apply the schema of the existing table to the new data.
    val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
    relation.insert(df, overwrite)

    // Invalidate the cache.

Example 45
Source File: SparkSQLParser.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.StringType

/**
 * Parser for Spark-specific commands. `start` tries, in order: cache, uncache, set,
 * show, desc, and finally `others`, which hands the whole input to `fallback`.
 *
 * NOTE(review): closing braces/parentheses and a few sub-expressions (the second
 * half of the `SetCommand` pair, the middle argument of `CacheTableCommand`) were
 * lost in extraction.
 *
 * @param fallback parser applied to any input that is not one of the commands above
 */
class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {

  // A parser for the key-value part of the "SET [key = [value ]]" syntax
  private object SetCommandParser extends RegexParsers {
    // Key: one or more characters other than '='.
    private val key: Parser[String] = "(?m)[^=]+".r

    // Value: any characters up to a line end (multiline mode is enabled by `(?m)`).
    private val value: Parser[String] = "(?m).*$".r

    private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)())

    // Both the whole pair and its "= value" part are optional.
    private val pair: Parser[LogicalPlan] =
      (key ~ ("=".r ~> value).?).? ^^ {
        case None => SetCommand(None)
        case Some(k ~ v) => SetCommand(Some(k.trim ->

    // Parses the complete input; any non-success result is raised via sys.error.
    def apply(input: String): LogicalPlan = parseAll(pair, input) match {
      case Success(plan, _) => plan
      case x => sys.error(x.toString)

  protected val AS = Keyword("AS")
  protected val CACHE = Keyword("CACHE")
  protected val CLEAR = Keyword("CLEAR")
  protected val DESCRIBE = Keyword("DESCRIBE")
  protected val EXTENDED = Keyword("EXTENDED")
  protected val FUNCTION = Keyword("FUNCTION")
  protected val FUNCTIONS = Keyword("FUNCTIONS")
  protected val IN = Keyword("IN")
  protected val LAZY = Keyword("LAZY")
  protected val SET = Keyword("SET")
  protected val SHOW = Keyword("SHOW")
  protected val TABLE = Keyword("TABLE")
  protected val TABLES = Keyword("TABLES")
  protected val UNCACHE = Keyword("UNCACHE")

  // Entry point: alternatives are tried left to right, `others` is the catch-all.
  override protected lazy val start: Parser[LogicalPlan] =
    cache | uncache | set | show | desc | others

  // CACHE [LAZY] TABLE <ident> [AS <rest of input>]
  private lazy val cache: Parser[LogicalPlan] =
    CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
      case isLazy ~ tableName ~ plan =>
        CacheTableCommand(tableName,, isLazy.isDefined)

  // UNCACHE TABLE <ident>  |  CLEAR CACHE
  private lazy val uncache: Parser[LogicalPlan] =
    ( UNCACHE ~ TABLE ~> ident ^^ {
        case tableName => UncacheTableCommand(tableName)
    | CLEAR ~ CACHE ^^^ ClearCacheCommand

  // SET <rest of input>; the remainder is parsed by SetCommandParser.
  private lazy val set: Parser[LogicalPlan] =
    SET ~> restInput ^^ {
      case input => SetCommandParser(input)

  // SHOW TABLES [IN <db>], or SHOW FUNCTIONS in one of the following patterns:
  // SHOW FUNCTIONS mydb.func1;
  // SHOW FUNCTIONS func1;
  // SHOW FUNCTIONS `mydb.a`.`func1.aa`;
  private lazy val show: Parser[LogicalPlan] =
    ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ {
        case _ ~ dbName => ShowTablesCommand(dbName)
    | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ {
        case Some(f) => logical.ShowFunctions(f._1, Some(f._2))
        case None => logical.ShowFunctions(None, None)

  // DESCRIBE FUNCTION [EXTENDED] <ident or string literal>
  private lazy val desc: Parser[LogicalPlan] =
    DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ {
      case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined)

  // Everything else is delegated to the fallback parser.
  private lazy val others: Parser[LogicalPlan] =
    wholeInput ^^ {
      case input => fallback(input)

Example 46
Source File: ComputeCurrentTimeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Optimizer tests for the `ComputeCurrentTime` rule. Each test collects the literal
 * values found in the optimized plan and asserts that there are exactly two, that
 * both lie within a [min, max] window measured around the optimizer run, and that
 * they are equal to each other.
 *
 * NOTE(review): closing braces, the child plan of the first `Project` and the
 * result expression of each `transformAllExpressions` case were lost in extraction.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  // The rule is applied exactly once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),

    // Milliseconds scaled by 1000 -- presumably microseconds, TODO confirm.
    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    // The upper bound adds one millisecond before scaling.
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    // Both bounds come from DateTimeUtils.millisToDays around the optimizer run.
    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
Example 47
Source File: ReorderAssociativeOperatorSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Optimizer tests for the `ReorderAssociativeOperator` rule.
 *
 * NOTE(review): the relation attributes, the `select(` openers of both plans in the
 * first test, the relation receivers in the second test and all closing braces were
 * lost in extraction.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // The rule is applied exactly once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil

  val testRelation = LocalRelation(', ', '

  test("Reorder associative operators") {
    val originalQuery =
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    // Expected: constants folded into one operand per expression, aliased with the
    // original expression text; the Rand(0) expression is listed unchanged.
    val correctAnswer =
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)

    comparePlans(optimized, correctAnswer)

  test("nested expression with aggregate operator") {
    val originalQuery ="t1")
        .join("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    // The optimized plan is expected to equal the analyzed query unchanged.
    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
Example 48
Source File: AggregateOptimizeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Optimizer tests for grouping-expression cleanup, using a case-insensitive analyzer
 * (`caseSensitiveAnalysis = false`). The only rule visible in the batch is
 * `RemoveRepetitionFromGroupExpressions`.
 *
 * NOTE(review): the relation attributes, the `select(` receivers in the aliased-literal
 * test and all closing braces were lost in extraction; the batch may have contained
 * further rules on lines that were dropped.
 */
class AggregateOptimizeSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      RemoveRepetitionFromGroupExpressions) :: Nil

  val testRelation = LocalRelation(', ', '

  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    // Expected: a single Literal(0) grouping key remains.
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)

  test("Remove aliased literals") {
    val query ='a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer ='a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  test("remove repetition in grouping expression") {
    val input = LocalRelation(', ', '
    // The last two keys repeat the first two with swapped operands and upper-case names.
    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
Example 49
Source File: SparkPlanner.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.internal.SQLConf

/**
 * Physical planner: user-supplied `extraStrategies` are tried first, followed by the
 * built-in strategies listed in `strategies`.
 *
 * NOTE(review): closing braces, the body of `prunePlans`, the right-hand side of
 * `filterCondition` and parts of the final `if`/`else` in `pruneFilterProject` were
 * lost in extraction.
 *
 * @param sparkContext    the Spark context
 * @param conf            SQL configuration; supplies `numShufflePartitions`
 * @param extraStrategies additional strategies, placed ahead of the built-in ones
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  def numPartitions: Int = conf.numShufflePartitions

  // Order matters: strategies are listed in the order they are tried.
  def strategies: Seq[Strategy] =
      extraStrategies ++ (
      FileSourceStrategy ::
      DataSourceStrategy ::
      DDLStrategy ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)

  // Pairs every PlanLater placeholder in the plan with the logical plan it wraps.
  override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
    plan.collect {
      case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan

  override protected def prunePlans(plans: Iterator[SparkPlan]): Iterator[SparkPlan] = {
    // TODO: We will need to prune bad plans when we improve plan space exploration
    //       to prevent combinatorial explosion.

  /**
   * Builds a scan via `scanBuilder`, adding a filter and/or projection on top as needed.
   *
   * @param projectList            expressions the resulting plan must output
   * @param filterPredicates       predicates to apply
   * @param prunePushedDownFilters function over the predicates; its use is not visible
   *                               in this listing (presumably removes filters already
   *                               handled by the scan -- TODO confirm)
   * @param scanBuilder            creates the scan for a given attribute list
   */
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    val filterCondition: Option[Expression] =

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    if (AttributeSet( == projectSet &&
        filterSet.subsetOf(projectSet)) {
      // When it is possible to just use column pruning to get the right projection and
      // when the columns of this projection are enough to evaluate all filter conditions,
      // just do a scan followed by a filter, with no extra project.
      val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]), scan)).getOrElse(scan)
    } else {
      // Otherwise scan every column needed by either side and project on top.
      val scan = scanBuilder((projectSet ++ filterSet).toSeq)
      ProjectExec(projectList,, scan)).getOrElse(scan))
Example 50
Source File: PruneFileSourcePartitions.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Logical rule that prunes partitions of a file-source relation backed by a
 * `CatalogFileIndex`. When partition-key filters exist, the file index is narrowed
 * with `filterPartitions` and the plan is rebuilt as
 * `Project(projects, Filter(<all filters AND-ed>, <pruned relation>))`.
 *
 * NOTE(review): parts of the pattern match, the bodies of `normalizedFilters`,
 * `partitionColumns` and `partitionKeyFilters`, the `else` branch and all closing
 * braces were lost in extraction.
 */
private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    // Applies only when there is at least one filter and a partition schema is defined.
    case op @ PhysicalOperation(projects, filters,
        logicalRelation @
          LogicalRelation(fsRelation @
              catalogFileIndex: CatalogFileIndex,
        if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
      // The attribute name in a predicate can differ from the one in the schema when
      // analysis is case insensitive. Rewrite the names to match the schema so that we
      // do not need to worry about case sensitivity anymore.
      val normalizedFilters = { e =>
        e transform {
          case a: AttributeReference =>

      val sparkSession = fsRelation.sparkSession
      val partitionColumns =
          partitionSchema, sparkSession.sessionState.analyzer.resolver)
      val partitionSet = AttributeSet(partitionColumns)
      val partitionKeyFilters =

      if (partitionKeyFilters.nonEmpty) {
        val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
        val prunedFsRelation =
          fsRelation.copy(location = prunedFileIndex)(sparkSession)
        // The pruned relation is given the original relation's output as its expected output.
        val prunedLogicalRelation = logicalRelation.copy(
          relation = prunedFsRelation,
          expectedOutputAttributes = Some(logicalRelation.output))

        // Keep partition-pruning predicates so that they are visible in physical planning
        val filterExpression = filters.reduceLeft(And)
        val filter = Filter(filterExpression, prunedLogicalRelation)
        Project(projects, filter)
      } else {
Example 51
Source File: LogicalRelation.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.util.Utils

  // NOTE(review): the method body and the enclosing class header are missing from
  // this listing (extraction damage); restore from the upstream LogicalRelation.scala.
  override def newInstance(): this.type = {

  /**
   * Refreshes the file listing when the wrapped relation is a `HadoopFsRelation`;
   * a no-op for every other relation type.
   */
  override def refresh(): Unit = relation match {
    case fs: HadoopFsRelation => fs.location.refresh()
    case _ =>  // Do nothing.

  /** One-line description: `Relation[<output, joined via Utils.truncatedString>] <relation>`. */
  override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
Example 52
Source File: InsertIntoDataSourceCommand.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OverwriteOptions}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

/**
 * Runnable command that inserts the result of `query` into an `InsertableRelation`.
 *
 * NOTE(review): the statements after the "Invalidate the cache" comment, the return
 * value and the closing braces are missing from this listing.
 *
 * @param logicalRelation target relation; its `relation` is cast to `InsertableRelation`
 * @param query           logical plan producing the rows to insert
 * @param overwrite       overwrite options; only `overwrite.enabled` is read here
 */
case class InsertIntoDataSourceCommand(
    logicalRelation: LogicalRelation,
    query: LogicalPlan,
    overwrite: OverwriteOptions)
  extends RunnableCommand {

  // The query is exposed as an inner child rather than a regular child.
  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Unchecked cast: assumes the target relation supports inserts.
    val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
    val data = Dataset.ofRows(sparkSession, query)
    // Apply the schema of the existing table to the new data.
    val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
    relation.insert(df, overwrite.enabled)

    // Invalidate the cache.

Example 53
Source File: ddl.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types._

/**
 * Logical command describing a table to create, optionally populated from `query`.
 * Construction asserts that `tableDesc.provider` is defined.
 *
 * NOTE(review): the call wrapping the two arguments inside `if (query.isEmpty)`
 * (presumably `assert(` -- TODO confirm) and the closing braces were lost in extraction.
 *
 * @param tableDesc catalog description of the table; must have a provider
 * @param mode      save mode; without a query only ErrorIfExists or Ignore is accepted
 * @param query     optional plan whose result is inserted into the new table
 */
case class CreateTable(
    tableDesc: CatalogTable,
    mode: SaveMode,
    query: Option[LogicalPlan]) extends Command {
  assert(tableDesc.provider.isDefined, "The table to be created must have a provider.")

  if (query.isEmpty) {
      mode == SaveMode.ErrorIfExists || mode == SaveMode.Ignore,
      "create table without data insertion can only use ErrorIfExists or Ignore as SaveMode.")

  // The optional query is exposed as an inner child.
  override def innerChildren: Seq[QueryPlan[_]] = query.toSeq

/**
 * Runnable command that creates a temporary view backed by a data source.
 * Construction throws an `AnalysisException` when `tableIdent` carries a database.
 * `run` resolves the data source into a logical plan and registers it as a global
 * temp view when `global` is true, otherwise as a session temp view.
 *
 * NOTE(review): part of `argString`, the first argument(s) of `DataSource(...)`, the
 * return value of `run` and all closing braces were lost in extraction.
 *
 * @param tableIdent          view identifier; must not specify a database
 * @param userSpecifiedSchema optional schema passed to the data source
 * @param replace             passed to the catalog's create-view call
 * @param global              selects `createGlobalTempView` over `createTempView`
 * @param provider            data source class name
 * @param options             data source options
 */
case class CreateTempViewUsing(
    tableIdent: TableIdentifier,
    userSpecifiedSchema: Option[StructType],
    replace: Boolean,
    global: Boolean,
    provider: String,
    options: Map[String, String]) extends RunnableCommand {

  if (tableIdent.database.isDefined) {
    throw new AnalysisException(
      s"Temporary view '$tableIdent' should not have specified a database")

  override def argString: String = {
    s"[tableIdent:$tableIdent " + + " ").getOrElse("") +
      s"replace:$replace " +
      s"provider:$provider " +

  def run(sparkSession: SparkSession): Seq[Row] = {
    val dataSource = DataSource(
      userSpecifiedSchema = userSpecifiedSchema,
      className = provider,
      options = options)

    val catalog = sparkSession.sessionState.catalog
    // Resolve the data source relation and take the Dataset's logical plan as the view body.
    val viewDefinition = Dataset.ofRows(
      sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan

    if (global) {
      catalog.createGlobalTempView(tableIdent.table, viewDefinition, replace)
    } else {
      catalog.createTempView(tableIdent.table, viewDefinition, replace)


/**
 * Runnable command that refreshes the metadata of the given table.
 *
 * NOTE(review): the statements of `run` and the closing braces are missing from this
 * listing; only the original explanatory comment survives.
 *
 * @param tableIdent identifier of the table to refresh
 */
case class RefreshTable(tableIdent: TableIdentifier)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Refresh the given table's metadata. If this table is cached as an InMemoryRelation,
    // drop the original cached version and make the new version cached lazily.

/**
 * Runnable command parameterized by a resource path.
 *
 * NOTE(review): the body of `run` is missing from this listing, so what is refreshed
 * cannot be determined here -- presumably cached data for `path`, TODO confirm upstream.
 *
 * @param path resource path the command operates on
 */
case class RefreshResource(path: String)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
Example 54
Source File: cache.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * Runnable command for caching a table, optionally defined by a query plan.
 *
 * @param tableIdent identifier of the table; must not carry a database when `plan` is given
 * @param plan       optional logical plan; when present it is registered as a temp view
 *                   named after `tableIdent` before anything else happens
 * @param isLazy     when false the eager-caching branch of `run` is taken
 */
case class CacheTableCommand(
    tableIdent: TableIdentifier,
    plan: Option[LogicalPlan],
    isLazy: Boolean) extends RunnableCommand {
  // Precondition: a plan (CACHE TABLE ... AS SELECT) excludes a database-qualified name.
  require(plan.isEmpty || tableIdent.database.isEmpty,
    "Database name is not allowed in CACHE TABLE AS SELECT")

  // NOTE(review): the body of innerChildren is not shown in this listing.
  override protected def innerChildren: Seq[QueryPlan[_]] = {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // If a plan was supplied, expose it as a temp view under the quoted table name.
    plan.foreach { logicalPlan =>
      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)

    // NOTE(review): the caching calls themselves are not shown in this listing.
    if (!isLazy) {
      // Performs eager caching


/**
 * Runnable command for un-caching a table.
 *
 * @param tableIdent identifier of the table; its quoted string form is used in `run`
 * @param ifExists   when true, a `NoSuchTableException` raised inside the `try` is swallowed
 */
case class UncacheTableCommand(
    tableIdent: TableIdentifier,
    ifExists: Boolean) extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val tableId = tableIdent.quotedString
    // NOTE(review): the statement inside `try` (presumably the call that uses `tableId`)
    // is not shown in this listing.
    try {
    } catch {
      // Only suppressed when the caller asked for IF EXISTS semantics; otherwise it propagates.
      case _: NoSuchTableException if ifExists => // don't throw

/**
 * Parameterless runnable command (a singleton `case object`).
 *
 * NOTE(review): the body of `run` is not shown in this listing.
 */
case object ClearCacheCommand extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
Example 55
Source File: commands.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.streaming.IncrementalExecution
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._

/**
 * Runnable command that explains a logical plan.
 *
 * @param logicalPlan the plan to explain
 * @param output      result schema; defaults to one nullable string column named "plan"
 * @param extended    selects the `extended` branch when building the output string
 * @param codegen     selects the `codegen` branch; checked before `extended`
 */
case class ExplainCommand(
    logicalPlan: LogicalPlan,
    override val output: Seq[Attribute] =
      Seq(AttributeReference("plan", StringType, nullable = true)()),
    extended: Boolean = false,
    codegen: Boolean = false)
  extends RunnableCommand {

  // Run through the optimizer to generate the physical plan.
  override def run(sparkSession: SparkSession): Seq[Row] = try {
    val queryExecution =
      if (logicalPlan.isStreaming) {
        // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
        // output mode does not matter since there is no `Sink`.
        new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, 0)
      } else {
    // NOTE(review): the non-streaming branch above and all three branch bodies below are
    // not shown in this listing.
    val outputString =
      if (codegen) {
      } else if (extended) {
      } else {
  // A planning failure is reported as rows (one per line of the message) instead of thrown.
  } catch { case cause: TreeNodeException[_] =>
    ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
Example 56
Source File: QueryExecutionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Test suite checking how `QueryExecution.toString` behaves when a planning strategy throws.
 *
 * A strategy with a mutable `mode` is installed as an extra strategy; the test flips the
 * mode between "", "exception" and "error".
 */
class QueryExecutionSuite extends SharedSQLContext {
  test("toString() exception/error handling") {
    // Strategy whose behaviour is switched by `mode` (compared case-insensitively).
    val badRule = new SparkStrategy {
      var mode: String = ""
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match {
        case "exception" => throw new AnalysisException(mode)
        case "error" => throw new Error(mode)
        case _ => Nil
    spark.experimental.extraStrategies = badRule :: Nil

    // A `def`, so every use builds a fresh QueryExecution.
    def qe: QueryExecution = new QueryExecution(spark, OneRowRelation)

    // NOTE(review): the assertions for the first two modes are not shown in this listing.
    // Nothing!
    badRule.mode = ""

    // Throw an AnalysisException - this should be captured.
    badRule.mode = "exception"

    // Throw an Error - this should not be captured.
    badRule.mode = "error"
    val error = intercept[Error](qe.toString)
Example 57
Source File: SparkPlannerSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Test suite for the query planner's branch selection.
 *
 * A counting strategy returns, for each matched node, a real candidate followed by a
 * candidate that defers to the `NeverPlanned` node; planning `NeverPlanned` fails the test.
 */
class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Sentinel leaf node: reaching it in the strategy triggers `fail(...)` below.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil

    // Number of times the strategy matched one of the three handled node types.
    var planned = 0
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          // NOTE(review): the argument of `UnionExec(` is missing in this listing.
          UnionExec( :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // The test expects exactly four strategy matches for this query.
      assert(planned === 4)
    } finally {
      // Always uninstall the test strategy so later tests are unaffected.
      spark.experimental.extraStrategies = Nil
Example 58
Source File: ExtraStrategiesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Leaf physical operator used by the tests below.
 *
 * @param output attributes this operator reports as its output (and as produced attributes)
 */
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    // Builds one row holding the literal "so fast" and projects it to an unsafe row.
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()
    // NOTE(review): the expression that turns `unsafeRow` into the returned RDD is not
    // shown in this listing.

  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil

/**
 * Planner strategy that replaces a single-column `Project` with a `FastOperator` when
 * the guard matches, and yields no plan (`Nil`) for everything else.
 */
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // NOTE(review): the left-hand side of the guard (`if == "a"`) is missing in this
    // listing; presumably it compares the projected attribute's name with "a" — confirm.
    case Project(Seq(attr), _) if == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil

/**
 * Test suite that installs `TestStrategy` as an extra planner strategy and checks the
 * answers of queries over a two-column DataFrame.
 */
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      // NOTE(review): both `checkAnswer(...)` calls are garbled in this listing; only the
      // expected rows ("so fast" and ("so slow", 1)) remain intact.
        Row("so fast"))

      checkAnswer("a", "b"),
        Row("so slow", 1))
    } finally {
      // Always uninstall the strategy so later tests are unaffected.
      spark.experimental.extraStrategies = Nil
Example 59
Source File: package.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up

import scala.language.implicitConversions

import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import{JoinEdge, ModularPlan}

/**
 * DSL helpers that add builder-style methods to modular and logical plans through
 * implicit wrapper classes.
 *
 * NOTE(review): several method bodies are missing or partial in this listing.
 */
package object dsl {

  object Plans {

    /** Adds `select`, `groupBy` and `harmonize` to a single `ModularPlan`. */
    implicit class DslModularPlan(val modularPlan: ModularPlan) {
      // Curried builder taking output, input and predicate expressions, an alias map and
      // join edges. NOTE(review): body not shown.
      def select(outputExprs: NamedExpression*)
        (inputExprs: Expression*)
        (predicateExprs: Expression*)
        (aliasMap: Map[Int, String])
        (joinEdges: JoinEdge*): ModularPlan = {

      // Builds a GroupBy node with `modularPlan` as its fifth constructor argument.
      // NOTE(review): the qualifier before `.GroupBy` is missing in this listing.
      def groupBy(outputExprs: NamedExpression*)
        (inputExprs: Expression*)
        (predicateExprs: Expression*): ModularPlan = {
          .GroupBy(outputExprs, inputExprs, predicateExprs, None, modularPlan, NoFlags, Seq.empty)

      // Shorthand for the plan's own `harmonized` member.
      def harmonize: ModularPlan = modularPlan.harmonized

    /** Adds `select` and `union` to a sequence of `ModularPlan`s. */
    implicit class DslModularPlans(val modularPlans: Seq[ModularPlan]) {
      // Same curried shape as the single-plan `select`. NOTE(review): body not shown.
      def select(outputExprs: NamedExpression*)
        (inputExprs: Expression*)
        (predicateList: Expression*)
        (aliasMap: Map[Int, String])
        (joinEdges: JoinEdge*): ModularPlan = {

      // Wraps all wrapped plans in one `modular.Union` node.
      def union(): ModularPlan = modular.Union(modularPlans, NoFlags, Seq.empty)

    /** Adds analysis, modularization and optimization shortcuts to a `LogicalPlan`. */
    implicit class DslLogical2Modular(val logicalPlan: LogicalPlan) {
      // Runs the plan through `analysis.SimpleAnalyzer`.
      def resolveonly: LogicalPlan = analysis.SimpleAnalyzer.execute(logicalPlan)

      // Takes the first element produced by the modularizer (via `.next`).
      def modularize: ModularPlan = modular.SimpleModularizer.modularize(logicalPlan).next

      // Runs the plan through `BirdcageOptimizer`.
      def optimize: LogicalPlan = BirdcageOptimizer.execute(logicalPlan)


Example 60
Source File: package.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, PredicateHelper, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import{CheckSPJG, LogicalPlanSignatureGenerator, Signature}

    /**
     * Checks `exp` against the `ScalaUDF` entries of `exprList`.
     *
     * Visible logic: for each `ScalaUDF` in the list with the same number of children as
     * `exp`, a comparison of the children's SQL text (case-insensitive) sets `canBeDerived`.
     *
     * NOTE(review): the collection being tested on the inner `if` line, the values
     * returned to `forall`, and the method's final result are missing in this listing.
     *
     * @param exp      the UDF expression to look for
     * @param exprList candidate expressions
     */
    def canEvaluate(exp: ScalaUDF, exprList: Seq[Expression]): Boolean = {
      var canBeDerived = false
      exprList.forall {
        case udf: ScalaUDF =>
          // Only UDFs with the same arity as `exp` are compared child by child.
          if (udf.children.length == exp.children.length) {
            if ( => e._1.sql.equalsIgnoreCase(e._2.sql))) {
              canBeDerived = true
        case _ =>

    /**
     * Overload for arbitrary expressions: a `ScalaUDF` is delegated to the
     * `ScalaUDF`-specific overload.
     *
     * NOTE(review): the handling of every other expression type (`case _`) is not shown
     * in this listing.
     *
     * @param expr     expression to check
     * @param exprList candidate expressions
     */
    def canEvaluate(expr: Expression, exprList: Seq[Expression]): Boolean = {
      expr match {
        case exp: ScalaUDF =>
          canEvaluate(exp, exprList)
        case _ =>

  /**
   * Guard that fails fast when an operation is not supported.
   *
   * @param supported whether the operation is supported; nothing happens when true
   * @param message   detail interpolated into the exception message
   * @throws UnsupportedOperationException with text `unsupported operation: <message>`
   *                                       when `supported` is false
   */
  def supports(supported: Boolean, message: Any): Unit = {
    if (!supported) {
      throw new UnsupportedOperationException(s"unsupported operation: $message")
    }
  }
Example 61
Source File: CarbonMVRules.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.CarbonReflectionUtils

/**
 * Rule that leaves `Command` plans untouched and, for any other plan, branches on
 * whether `mvPlan` was initialised.
 *
 * NOTE(review): how `mvPlan` is created, what the exception handler yields, and both
 * branch bodies in `apply` are not shown in this listing.
 *
 * @param sparkSession the active session
 */
case class CarbonMVRules(sparkSession: SparkSession) extends Rule[LogicalPlan] {

  // Initialised inside a try/catch; `apply` treats `null` as "not available".
  val mvPlan = try {
  } catch {
    case e: Exception =>

  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan match {
      // Commands are returned unchanged.
      case _: Command => plan
      case _ =>
        if (mvPlan != null) {
        } else {
Example 62
Source File: CarbonSqlAstBuilder.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.parser.ParserUtils.string
import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{AddTableColumnsContext, CreateHiveTableContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.parser.{CarbonHelperSqlAstBuilder, CarbonSpark2SqlParser, CarbonSparkSqlParserUtil}

/**
 * AST builder that extends `SparkSqlAstBuilder` with Carbon handling for CREATE TABLE
 * and ADD COLUMNS.
 *
 * @param conf         SQL configuration passed to the parent builder and the helper
 * @param parser       Carbon parser passed to the helper and to `visitAddTableColumns`
 * @param sparkSession the active session, passed to the helper
 */
class CarbonSqlAstBuilder(conf: SQLConf, parser: CarbonSpark2SqlParser, sparkSession: SparkSession)
  extends SparkSqlAstBuilder(conf) with SqlAstBuilderHelper {

  val helper = new CarbonHelperSqlAstBuilder(conf, parser, sparkSession)

  override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = {
    val fileStorage = CarbonSparkSqlParserUtil.getFileStorage(ctx.createFileFormat(0))

    // The Carbon branch is taken for these storage names, compared case-insensitively
    // (three quoted forms and the unquoted "carbondata").
    if (fileStorage.equalsIgnoreCase("'carbondata'") ||
        fileStorage.equalsIgnoreCase("carbondata") ||
        fileStorage.equalsIgnoreCase("'carbonfile'") ||
        fileStorage.equalsIgnoreCase("'org.apache.carbondata.format'")) {
      // Bundles the parse-tree pieces of the CREATE TABLE statement into one tuple.
      // NOTE(review): the call that consumes this tuple and the `else` body are not shown.
      val createTableTuple = (ctx.createTableHeader, ctx.skewSpec(0),
        ctx.bucketSpec(0), ctx.partitionColumns, ctx.columns, ctx.tablePropertyList(0),
        ctx.locationSpec(0), Option(ctx.STRING(0)).map(string), ctx.AS, ctx.query, fileStorage)
    } else {

  // Delegates to the two-argument overload, supplying this builder's parser.
  override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = {
    visitAddTableColumns(parser, ctx)
Example 63
Source File: CarbonAnalyzer.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CarbonReflectionUtils

/**
 * Analyzer that first runs the wrapped `analyzer` and then branches on whether `mvPlan`
 * was initialised.
 *
 * NOTE(review): how `mvPlan` is created, what the exception handler yields, and both
 * branch bodies in `execute` are not shown in this listing.
 *
 * @param catalog      session catalog passed to the parent `Analyzer`
 * @param conf         SQL configuration passed to the parent `Analyzer`
 * @param sparkSession the active session
 * @param analyzer     the analyzer whose `execute` result is post-processed
 */
class CarbonAnalyzer(catalog: SessionCatalog,
    conf: SQLConf,
    sparkSession: SparkSession,
    analyzer: Analyzer) extends Analyzer(catalog, conf) {

  // Initialised inside a try/catch; `execute` treats `null` as "not available".
  val mvPlan = try {
  } catch {
    case e: Exception =>

  override def execute(plan: LogicalPlan): LogicalPlan = {
    // Delegate the analysis itself to the wrapped analyzer.
    val logicalPlan = analyzer.execute(plan)
    if (mvPlan != null) {
    } else {
Example 64
Source File: CarbonResetCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution.command

import org.apache.spark.sql.{CarbonEnv, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.{ResetCommand, RunnableCommand}

/**
 * Runnable command that reuses the output schema of Spark's `ResetCommand`.
 *
 * NOTE(review): the body of `run` is not shown in this listing.
 */
case class CarbonResetCommand()
  extends RunnableCommand {
  override val output = ResetCommand.output

  override def run(sparkSession: SparkSession): Seq[Row] = {

/**
 * Extractor object whose `unapply` distinguishes `ResetCommand` from other plans.
 *
 * NOTE(review): the values returned by the two cases are not shown in this listing.
 */
object MatchResetCommand {
  def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
    plan match {
      // Matches the `ResetCommand` singleton and binds it to `r`.
      case r@ResetCommand =>
      case _ =>
Example 65
Source File: SqlAstBuilderHelper.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.catalyst.parser.ParserUtils.{string, withOrigin}
import org.apache.spark.sql.catalyst.parser.SqlBaseParser
import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{AddTableColumnsContext, ChangeColumnContext, CreateTableContext, ShowTablesContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.CarbonParserUtil
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.execution.command.{AlterTableAddColumnsModel, AlterTableDataTypeChangeModel}
import org.apache.spark.sql.execution.command.schema.{CarbonAlterTableAddColumnCommand, CarbonAlterTableColRenameDataTypeChangeCommand}
import org.apache.spark.sql.execution.command.table.{CarbonExplainCommand, CarbonShowTablesCommand}
import org.apache.spark.sql.parser.CarbonSpark2SqlParser
import org.apache.spark.sql.types.DecimalType

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties

/**
 * Mixin for `SparkSqlAstBuilder` that overrides the visitors for CHANGE COLUMN,
 * ADD COLUMNS, CREATE TABLE, SHOW TABLES and EXPLAIN.
 *
 * NOTE(review): most visitor bodies are partial or missing in this listing; comments
 * below describe only the statements that are visible.
 */
trait SqlAstBuilderHelper extends SparkSqlAstBuilder {

  override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = {

    val newColumn = visitColType(ctx.colType)
    // Rename detection compares identifier text case-insensitively.
    // NOTE(review): the right-hand operand and both branch values are missing here.
    val isColumnRename = if (!ctx.identifier.getText.equalsIgnoreCase( {
    } else {

    // Decimal types carry their (precision, scale) pair; any other type is represented
    // by its lower-cased type name with no extra values.
    val (typeString, values): (String, Option[List[(Int, Int)]]) = newColumn.dataType match {
      case d: DecimalType => ("decimal", Some(List((d.precision, d.scale))))
      case _ => (newColumn.dataType.typeName.toLowerCase, None)

    // NOTE(review): the constructor call that wraps the parsed data type is missing here.
    val alterTableColRenameAndDataTypeChangeModel =
        CarbonParserUtil.parseDataType(typeString, values, isColumnRename),


  def visitAddTableColumns(parser: CarbonSpark2SqlParser,
      ctx: AddTableColumnsContext): LogicalPlan = {
    // `Option(...)` guards against a null `ctx.columns`, yielding an empty column list.
    val cols = Option(ctx.columns).toSeq.flatMap(visitColTypeList)
    val fields = parser.getFields(cols)
    val tblProperties = scala.collection.mutable.Map.empty[String, String]
    // NOTE(review): the remaining arguments of both calls below are not shown.
    val tableModel = CarbonParserUtil.prepareTableModel(false,

    val alterTableAddColumnsModel = AlterTableAddColumnsModel(


  // NOTE(review): body not shown.
  override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = {

  // NOTE(review): only the `withOrigin(ctx)` wrapper is visible.
  override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = {
    withOrigin(ctx) {

  // NOTE(review): body not shown.
  override def visitExplain(ctx: SqlBaseParser.ExplainContext): LogicalPlan = {
Example 66
Source File: CarbonExtensions.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.strategy.{CarbonLateDecodeStrategy, DDLStrategy, StreamingTableStrategy}
import org.apache.spark.sql.hive.{CarbonIUDAnalysisRule, CarbonPreInsertionCasts}
import org.apache.spark.sql.parser.CarbonExtensionSqlParser

/**
 * Spark session extension entry point: a function from `SparkSessionExtensions` to
 * `Unit` that registers Carbon's parser, analyzer rules, optimizer hook and planner
 * strategies.
 *
 * NOTE(review): the receivers of the `.inject…` call chains below are missing in this
 * listing (presumably `extensions` — confirm).
 */
class CarbonExtensions extends (SparkSessionExtensions => Unit) {

  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Carbon internal parser
      .injectParser((sparkSession: SparkSession, parser: ParserInterface) =>
        new CarbonExtensionSqlParser(new SQLConf, sparkSession, parser))

    // carbon analyzer rules
      .injectResolutionRule((session: SparkSession) => CarbonIUDAnalysisRule(session))
      .injectResolutionRule((session: SparkSession) => CarbonPreInsertionCasts(session))

    // carbon optimizer rules
    // Registered through the post-hoc resolution hook rather than an optimizer hook.
    extensions.injectPostHocResolutionRule((session: SparkSession) => CarbonOptimizerRule(session))

    // carbon planner strategies
      .injectPlannerStrategy((session: SparkSession) => new StreamingTableStrategy(session))
      .injectPlannerStrategy((_: SparkSession) => new CarbonLateDecodeStrategy)
      .injectPlannerStrategy((session: SparkSession) => new DDLStrategy(session))

    // init CarbonEnv

/**
 * Rule whose first invocation looks up the session state's `optimizer` field by
 * reflection and constructs a `CarbonOptimizer` wrapping the existing optimizer.
 *
 * The `notAdded` flag is checked before and again inside a `synchronized` block on this
 * instance, and cleared there.
 *
 * NOTE(review): the statements that make the field accessible and write the new
 * optimizer, and the value `apply` returns, are not shown in this listing.
 *
 * @param session the session whose optimizer is replaced
 */
case class CarbonOptimizerRule(session: SparkSession) extends Rule[LogicalPlan] {
  self =>

  // True until the replacement has been attempted once.
  var notAdded = true

  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (notAdded) {
      self.synchronized {
        // Re-checked under the lock so only one caller performs the replacement.
        if (notAdded) {
          notAdded = false

          val sessionState = session.sessionState
          // Reflective handle on the declared `optimizer` field of the session state.
          val field = sessionState.getClass.getDeclaredField("optimizer")
            new CarbonOptimizer(session, sessionState.catalog, sessionState.optimizer))
Example 67
Source File: CarbonExtensionSqlParser.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.parser

import org.apache.spark.sql.{CarbonEnv, CarbonUtils, SparkSession}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CarbonException

import org.apache.carbondata.common.exceptions.sql.MalformedCarbonCommandException
import org.apache.carbondata.spark.util.CarbonScalaUtil

/**
 * SQL parser that tries the Carbon parser first and falls back to `initialParser`.
 *
 * `MalformedCarbonCommandException` from either parser is rethrown unchanged; any other
 * failure of the Carbon parser triggers the fallback.
 *
 * NOTE(review): the listing is truncated (values returned on success, and the handling
 * when both parsers fail, are incomplete) and ends inside a string literal.
 *
 * @param conf          SQL configuration passed to the parent `SparkSqlParser`
 * @param sparkSession  the active session
 * @param initialParser the parser used when the Carbon parser cannot handle the text
 */
class CarbonExtensionSqlParser(
    conf: SQLConf,
    sparkSession: SparkSession,
    initialParser: ParserInterface
) extends SparkSqlParser(conf) {

  val parser = new CarbonExtensionSpark2SqlParser

  override def parsePlan(sqlText: String): LogicalPlan = {
    // The Carbon parser instance is used under its own monitor.
    parser.synchronized {
    try {
      val plan = parser.parse(sqlText)
    } catch {
      // Carbon-specific syntax errors are surfaced directly, without trying the fallback.
      case ce: MalformedCarbonCommandException =>
        throw ce
      case ex: Throwable =>
        // Anything else: give the original parser a chance.
        try {
          val parsedPlan = initialParser.parsePlan(sqlText)
        } catch {
          case mce: MalformedCarbonCommandException =>
            throw mce
          case e: Throwable =>
              s"""== Parser1: ${parser.getClass.getName} ==
                 |== Parser2: ${initialParser.getClass.getName} ==
Example 68
Source File: StreamingTableStrategy.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.strategy

import org.apache.spark.sql.{CarbonEnv, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableChangeColumnCommand, AlterTableRenameCommand}
import org.apache.spark.sql.execution.command.mutation.{CarbonProjectForDeleteCommand, CarbonProjectForUpdateCommand}
import org.apache.spark.sql.execution.command.schema.{CarbonAlterTableAddColumnCommand, CarbonAlterTableColRenameDataTypeChangeCommand, CarbonAlterTableDropColumnCommand}

import org.apache.carbondata.common.exceptions.sql.MalformedCarbonCommandException

  /**
   * Throws when `operation` is attempted on a streaming table.
   *
   * Any exception raised while looking the table up is swallowed and the table is
   * treated as non-streaming.
   *
   * @param tableIdentifier table to look up through `CarbonEnv`
   * @param operation       operation name interpolated into the error message
   * @throws MalformedCarbonCommandException when the table is flagged as streaming
   */
  private def rejectIfStreamingTable(tableIdentifier: TableIdentifier, operation: String): Unit = {
    var streaming = false
    try {
      // NOTE(review): the accessor that turns the looked-up table into a Boolean appears
      // to be missing after the call below in this listing.
      streaming = CarbonEnv.getCarbonTable(
        tableIdentifier.database, tableIdentifier.table)(sparkSession)
    } catch {
      // Best effort: a failed lookup means "not streaming", so the operation is allowed.
      case e: Exception =>
        streaming = false
    if (streaming) {
      throw new MalformedCarbonCommandException(
        s"$operation is not allowed for streaming table")

  /**
   * Delegates to `CarbonPlanHelper.isCarbonTable`, passing the enclosing scope's
   * `sparkSession`, and returns its result.
   *
   * @param tableIdent identifier of the table to check
   */
  def isCarbonTable(tableIdent: TableIdentifier): Boolean = {
    CarbonPlanHelper.isCarbonTable(tableIdent, sparkSession)
  }
Example 69
Source File: CarbonCreateTableAsSelectCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.table

import org.apache.spark.sql.{CarbonEnv, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.AtomicRunnableCommand

import org.apache.carbondata.core.metadata.schema.table.TableInfo

/**
 * Atomic command for CREATE TABLE AS SELECT on a Carbon table.
 *
 * `processMetadata` creates the table (unless it already exists) and prepares an insert
 * command; `processData` acts only when that insert command was prepared; `undoMetadata`
 * drops the table again.
 *
 * NOTE(review): several statements are missing in this listing (the lookup call on
 * `carbonMetaStore`, the "fileheader" option value, the bodies of `processData` and of
 * the drop in `undoMetadata`, and the return values).
 *
 * @param tableInfo      metadata of the table to create
 * @param query          the SELECT plan whose result is loaded into the table
 * @param ifNotExistsSet when true, an existing table is not an error
 * @param tableLocation  optional storage location forwarded to the create command
 */
case class CarbonCreateTableAsSelectCommand(
    tableInfo: TableInfo,
    query: LogicalPlan,
    ifNotExistsSet: Boolean = false,
    tableLocation: Option[String] = None) extends AtomicRunnableCommand {

  // Set by processMetadata only when the table was created; null otherwise.
  var insertIntoCommand: CarbonInsertIntoCommand = _

  override def processMetadata(sparkSession: SparkSession): Seq[Row] = {
    val tableName = tableInfo.getFactTable.getTableName
    var isTableCreated = false
    // A null database name in tableInfo leaves this as None.
    var databaseOpt: Option[String] = None
    if (tableInfo.getDatabaseName != null) {
      databaseOpt = Some(tableInfo.getDatabaseName)
    val dbName = CarbonEnv.getDatabaseName(databaseOpt)(sparkSession)
    setAuditTable(dbName, tableName)
    setAuditInfo(Map("query" -> query.simpleString))
    // check if table already exists
    if (sparkSession.sessionState.catalog
      .tableExists(TableIdentifier(tableName, Some(dbName)))) {
      // An existing table is an error unless IF NOT EXISTS was given.
      if (!ifNotExistsSet) {
        throw new TableAlreadyExistsException(dbName, tableName)
    } else {
      // execute command to create carbon table
      CarbonCreateTableCommand(tableInfo, ifNotExistsSet, tableLocation).run(sparkSession)
      isTableCreated = true

    // The insert command is prepared only for a table created by this invocation.
    if (isTableCreated) {
      val tableName = tableInfo.getFactTable.getTableName
      var databaseOpt: Option[String] = None
      if (tableInfo.getDatabaseName != null) {
        databaseOpt = Some(tableInfo.getDatabaseName)
      val dbName = CarbonEnv.getDatabaseName(databaseOpt)(sparkSession)
      val carbonDataSourceHadoopRelation = CarbonEnv.getInstance(sparkSession).carbonMetaStore
          TableIdentifier(tableName, Option(dbName)))
      // execute command to load data into carbon table
      insertIntoCommand = CarbonInsertIntoCommand(
        databaseNameOp = Some(carbonDataSourceHadoopRelation.carbonRelation.databaseName),
        tableName = carbonDataSourceHadoopRelation.carbonRelation.tableName,
        options = scala.collection.immutable
          .Map("fileheader" ->
        isOverwriteTable = false,
        logicalPlan = query,
        tableInfo = tableInfo)

  override def processData(sparkSession: SparkSession): Seq[Row] = {
    // Acts only when processMetadata prepared an insert command.
    if (null != insertIntoCommand) {

  override def undoMetadata(sparkSession: SparkSession, exception: Exception): Seq[Row] = {
    val tableName = tableInfo.getFactTable.getTableName
    var databaseOpt: Option[String] = None
    if (tableInfo.getDatabaseName != null) {
      databaseOpt = Some(tableInfo.getDatabaseName)
    val dbName = CarbonEnv.getDatabaseName(databaseOpt)(sparkSession)
    // drop the created table.
      ifExistsSet = false,
      Option(dbName), tableName).run(sparkSession)

  override protected def opName: String = "CREATE TABLE AS SELECT"
Example 70
Source File: CarbonExplainCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.table

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, Union}
import org.apache.spark.sql.execution.command.{ExplainCommand, MetadataCommand}
import org.apache.spark.sql.types.StringType

import org.apache.carbondata.core.profiler.ExplainCollector

/**
 * Metadata command wrapping an `ExplainCommand`.
 *
 * @param child  the plan to explain; cast to `ExplainCommand` without a type check, so
 *               any other plan type fails with a `ClassCastException`
 * @param output result schema; defaults to one nullable string column named "plan"
 */
case class CarbonExplainCommand(
    child: LogicalPlan,
    override val output: Seq[Attribute] =
    Seq(AttributeReference("plan", StringType, nullable = true)()))
  extends MetadataCommand {

  override def processMetadata(sparkSession: SparkSession): Seq[Row] = {
    val explainCommand = child.asInstanceOf[ExplainCommand]
    setAuditInfo(Map("query" -> explainCommand.logicalPlan.simpleString))
    // True when the explained plan is a command, or a union made only of commands.
    val isCommand = explainCommand.logicalPlan match {
      case _: Command => true
      case Union(childern) if childern.forall(_.isInstanceOf[Command]) => true
      case _ => false

    // Profiler rows are collected only for plans that are neither streaming nor commands.
    // NOTE(review): the first branch body and the right-hand operand of `++` are not
    // shown in this listing.
    if (explainCommand.logicalPlan.isStreaming || isCommand) {
    } else {
      CarbonExplainCommand.collectProfiler(explainCommand, sparkSession) ++

  override protected def opName: String = "EXPLAIN"

/**
 * Metadata command that takes an `ExplainCommand` directly and collects profiler output.
 *
 * NOTE(review): the receiver of `.collectProfiler` and the right-hand operand of `++`
 * are missing in this listing.
 *
 * @param explainCommand the explain command to profile
 * @param output         result schema; defaults to one nullable string column named "plan"
 */
case class CarbonInternalExplainCommand(
    explainCommand: ExplainCommand,
    override val output: Seq[Attribute] =
    Seq(AttributeReference("plan", StringType, nullable = true)()))
  extends MetadataCommand {

  override def processMetadata(sparkSession: SparkSession): Seq[Row] = {
      .collectProfiler(explainCommand, sparkSession) ++

  override protected def opName: String = "Carbon EXPLAIN"

/** Companion holding the profiler-collection helper shared by the explain commands. */
object CarbonExplainCommand {
  /**
   * Produces a row carrying the formatted `ExplainCollector` output, prefixed with
   * "== CarbonData Profiler ==", when the collector is enabled.
   *
   * NOTE(review): the query-execution expression, the result for a null formatted
   * output, the disabled-collector branch and the `finally` body are not shown.
   *
   * @param explain      the explain command being profiled
   * @param sparkSession the active session
   */
  def collectProfiler(
      explain: ExplainCommand,
      sparkSession: SparkSession): Seq[Row] = {
    try {
      if (ExplainCollector.enabled()) {
        val queryExecution =
        // For count(*) queries the explain collector will be disabled, so profiler
        // informations not required in such scenarios.
        if (null == ExplainCollector.getFormatedOutput) {
        Seq(Row("== CarbonData Profiler ==\n" + ExplainCollector.getFormatedOutput))
      } else {
    } finally {
Example 71
Source File: IUDCommonUtil.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.mutation

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.hive.HiveSessionCatalog

import org.apache.carbondata.common.exceptions.sql.MalformedCarbonCommandException
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties

  /**
   * Walks `logicalPlan` and throws if the `CARBON_INPUT_SEGMENTS` property of any
   * referenced table is set to something other than "" or "*".
   *
   * Unresolved relations are keyed by the current database plus the table name; Carbon
   * logical relations build their own key; filters are searched recursively through
   * their subqueries.
   *
   * @param sparkSession the active session, used to find the current database
   * @param logicalPlan  the plan to inspect
   * @throws MalformedCarbonCommandException when a restricting segment list is set
   */
  def checkIfSegmentListIsSet(sparkSession: SparkSession, logicalPlan: LogicalPlan): Unit = {
    val carbonProperties = CarbonProperties.getInstance()
    logicalPlan.foreach {
      case unresolvedRelation: UnresolvedRelation =>
        // Key is "<current database>.<table>"; the catalog is assumed to be a HiveSessionCatalog.
        val dbAndTb =
          sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog].getCurrentDatabase +
          "." + unresolvedRelation.tableIdentifier.table
        val segmentProperties = carbonProperties
          .getProperty(CarbonCommonConstants.CARBON_INPUT_SEGMENTS + dbAndTb, "")
        if (!(segmentProperties.equals("") || segmentProperties.trim.equals("*"))) {
          // NOTE(review): there is no space between `dbAndTb` and "should", so the
          // message reads "...<db>.<table>should not be set..."; same in the UPDATE case.
          throw new MalformedCarbonCommandException("carbon.input.segments." + dbAndTb +
                                                    "should not be set for table used in DELETE " +
                                                    "query. Please reset the property to carbon" +
                                                    ".input.segments." +
                                                    dbAndTb + "=*")
      case logicalRelation: LogicalRelation if (logicalRelation.relation
        .isInstanceOf[CarbonDatasourceHadoopRelation]) =>
        // NOTE(review): the receiver of `.getDatabaseName` and the table-name operand are
        // missing in this listing.
        val dbAndTb =
            .getDatabaseName + "." +
        val sementProperty = carbonProperties
          .getProperty(CarbonCommonConstants.CARBON_INPUT_SEGMENTS + dbAndTb, "")
        if (!(sementProperty.equals("") || sementProperty.trim.equals("*"))) {
          throw new MalformedCarbonCommandException("carbon.input.segments." + dbAndTb +
                                                    "should not be set for table used in UPDATE " +
                                                    "query. Please reset the property to carbon" +
                                                    ".input.segments." +
                                                    dbAndTb + "=*")
      // Subqueries inside filter conditions are checked with the same rules.
      case filter: Filter => filter.subqueries.toList
        .foreach(subquery => checkIfSegmentListIsSet(sparkSession, subquery))
      case _ =>
Example 72
Source File: CarbonFileIndexReplaceRule.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.carbondata.execution.datasources

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, InMemoryFileIndex, InsertIntoHadoopFsRelationCommand, LogicalRelation}
import org.apache.spark.sql.sources.BaseRelation

import org.apache.carbondata.core.datastore.filesystem.CarbonFile
import org.apache.carbondata.core.datastore.impl.FileFactory
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.carbondata.core.util.path.CarbonTablePath

  /**
   * Recursively collects sub-directories of `tableFolder` into `dataFolders`.
   *
   * A directory is added when its listing is non-empty and its first entry is not a
   * directory; every other directory (including an empty one) is descended into.
   * Non-directory entries directly under `tableFolder` are ignored.
   *
   * @param tableFolder root of the scan
   * @param dataFolders accumulator, appended to in place
   */
  private def getDataFolders(
      tableFolder: CarbonFile,
      dataFolders: ArrayBuffer[CarbonFile]): Unit = {
    tableFolder.listFiles().foreach { entry =>
      if (entry.isDirectory) {
        // Listing of this sub-directory; only its first element is inspected.
        val children = entry.listFiles()
        if (children.nonEmpty && !children(0).isDirectory) {
          dataFolders += entry
        } else {
          getDataFolders(entry, dataFolders)
        }
      }
    }
  }
Example 73
Source File: CarbonUDFTransformRule.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, PredicateHelper,
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types.StringType

import org.apache.carbondata.core.constants.CarbonCommonConstants

/**
 * Rule that moves certain UDF projections from above a join into the join's left child.
 *
 * For a `Project` over a `Join`, aliases of a `ScalaUDF` named `POSITION_ID` or
 * `CARBON_IMPLICIT_COLUMN_TUPLEID` (case-insensitive) are replaced in the projection by
 * an attribute reference with the same expression id, and the original aliases are
 * appended to the output of the left child.
 *
 * NOTE(review): the body of `apply`, the collection traversed for `newCols`, and the
 * `else` branch are not shown in this listing.
 */
class CarbonUDFTransformRule extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan = {

  private def pushDownUDFToJoinLeftRelation(plan: LogicalPlan): LogicalPlan = {
    val output = plan.transform {
      case proj@Project(cols, Join(
      left, right, jointype: org.apache.spark.sql.catalyst.plans.JoinType, condition)) =>
        // Aliases removed from the top projection, to be re-added below the join.
        var projectionToBeAdded: Seq[org.apache.spark.sql.catalyst.expressions.Alias] = Seq.empty
        var udfExists = false
        val newCols = {
          case a@Alias(s: ScalaUDF, name)
            if name.equalsIgnoreCase(CarbonCommonConstants.POSITION_ID) ||
               name.equalsIgnoreCase(CarbonCommonConstants.CARBON_IMPLICIT_COLUMN_TUPLEID) =>
            udfExists = true
            projectionToBeAdded :+= a
            // Keep the alias's exprId so references above the join still resolve to it.
            AttributeReference(name, StringType, nullable = true)().withExprId(a.exprId)
          case other => other
        if (udfExists) {
          // Extend the left child's output; unrecognised node types are left unchanged.
          val newLeft = left match {
            case Project(columns, logicalPlan) =>
              Project(columns ++ projectionToBeAdded, logicalPlan)
            case filter: Filter =>
              Project(filter.output ++ projectionToBeAdded, filter)
            case relation: LogicalRelation =>
              Project(relation.output ++ projectionToBeAdded, relation)
            case other => other
          Project(newCols, Join(newLeft, right, jointype, condition))
        } else {
      case other => other

Example 74
Source File: CarbonIUDRule.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.optimizer

import org.apache.spark.sql.ProjectForUpdate
import org.apache.spark.sql.catalyst.expressions.{NamedExpression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.mutation.CarbonProjectForUpdateCommand

import org.apache.carbondata.core.constants.CarbonCommonConstants

/**
 * Rule over [[LogicalPlan]] that rewrites `ProjectForUpdate` nodes.
 *
 * NOTE(review): this listing looks truncated by extraction. The body of `apply`, several
 * closing braces and parts of a few expressions are missing, so only the visible logic is
 * described here.
 */
class CarbonIUDRule extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan = {

  /**
   * For each `ProjectForUpdate(table, cols, Seq(updatePlan))`, rewrites the first `Project`
   * matched inside `updatePlan`; the `isTransformed` flag stops later matches.
   *
   * The project list is split so that the trailing `cols.size` expressions form `source`
   * and everything before them forms `dest`. The rule then rejects updates whose column
   * name appears in the JSON of a `dest` data type, reports unknown columns through
   * `sys.error`, and substitutes `source` expressions into `dest` by index.
   */
  private def processPlan(plan: LogicalPlan): LogicalPlan = {
    plan transform {
      case ProjectForUpdate(table, cols, Seq(updatePlan)) =>
        var isTransformed = false
        val newPlan = updatePlan transform {
          case Project(pList, child) if !isTransformed =>
            // The last cols.size entries are the new values, the rest are the existing columns.
            var (dest: Seq[NamedExpression], source: Seq[NamedExpression]) = pList
              .splitAt(pList.size - cols.size)
            // check complex column
            cols.foreach { col =>
              val complexExists = "\"name\":\"" + col + "\""
              if (dest.exists(m => m.dataType.json.contains(complexExists))) {
                throw new UnsupportedOperationException(
                  "Unsupported operation on Complex data type")
            // check updated columns exists in table
            val diff = cols.diff(
            if (diff.nonEmpty) {
              sys.error(s"Unknown column(s) ${ diff.mkString(",") } in table ${ table.tableName }")
            // modify plan for updated column *in place*
            isTransformed = true
            source.foreach { col =>
              val colName =,
              val updateIdx = dest.indexWhere(
              dest = dest.updated(updateIdx, col)
            Project(dest, child)
          newPlan, table.tableIdentifier.database, table.tableIdentifier.table, cols)
Example 75
Source File: CarbonExpressions.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.types.DataType

  /**
   * Extractor over [[Expression]] whose `unapply` distinguishes `ScalaUDF` instances from
   * every other expression.
   *
   * NOTE(review): the result of each branch and the closing braces are missing from this
   * listing; presumably the first branch yields the matched UDF and the second yields
   * nothing -- confirm against the original source.
   */
  object CarbonScalaUDF {
    def unapply(expression: Expression): Option[(ScalaUDF)] = {
      expression match {
        case a: ScalaUDF =>
        case _ =>
Example 76
Source File: MVCoalesceTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.view.rewrite

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

/**
 * Query tests for materialized views combined with the `coalesce` expression, run against
 * the `coalesce_test_main` carbondata table.
 *
 * NOTE(review): the closing braces of every method and test are missing from this listing,
 * and no call to `drop()` is visible.
 */
class MVCoalesceTestCase  extends QueryTest with BeforeAndAfterAll  {
  // Creates the base table and inserts three rows: two named tom and one named lily.
  override def beforeAll(): Unit = {
    sql("create table coalesce_test_main(id int,name string,height int,weight int) " +
      "using carbondata")
    sql("insert into coalesce_test_main select 1,'tom',170,130")
    sql("insert into coalesce_test_main select 2,'tom',170,120")
    sql("insert into coalesce_test_main select 3,'lily',160,100")

  // Drops the base table if it exists.
  def drop(): Unit = {
    sql("drop table if exists coalesce_test_main")

  // The view stores a plain sum(id); the query wraps the sum in coalesce and groups by
  // fewer columns. The optimized plan is expected to hit the view.
  test("test mv table with coalesce expression on sql not on mv and less groupby cols") {
    sql("drop materialized view if exists coalesce_test_main_mv")
    sql("create materialized view coalesce_test_main_mv as " +
      "select sum(id) as sum_id,name as myname,weight from coalesce_test_main group by name,weight")
    sql("refresh materialized view coalesce_test_main_mv")

    val frame = sql("select coalesce(sum(id),0) as sumid,name from coalesce_test_main group by name")
    assert(TestUtil.verifyMVHit(frame.queryExecution.optimizedPlan, "coalesce_test_main_mv"))
    checkAnswer(frame, Seq(Row(3, "tom"), Row(3, "lily")))

    sql("drop materialized view if exists coalesce_test_main_mv")

  // Creating a view whose select list applies coalesce on top of the aggregate is expected
  // to fail with UnsupportedOperationException; the query must then not hit any view.
  test("test mv table with coalesce expression less groupby cols") {
    sql("drop materialized view if exists coalesce_test_main_mv")
    val exception: Exception = intercept[UnsupportedOperationException] {
      sql("create materialized view coalesce_test_main_mv as " +
        "select coalesce(sum(id),0) as sum_id,name as myname,weight from coalesce_test_main group by name,weight")
      sql("refresh materialized view coalesce_test_main_mv")
    assert("MV doesn't support Coalesce".equals(exception.getMessage))

    val frame = sql("select coalesce(sum(id),0) as sumid,name from coalesce_test_main group by name")
    assert(!TestUtil.verifyMVHit(frame.queryExecution.optimizedPlan, "coalesce_test_main_mv"))
    checkAnswer(frame, Seq(Row(3, "tom"), Row(3, "lily")))

    sql("drop materialized view if exists coalesce_test_main_mv")

  // Here coalesce is applied inside the aggregate, in both the view and the query; the
  // optimized plan is expected to hit the view.
  test("test mv table with coalesce expression in other expression") {
    sql("drop materialized view if exists coalesce_test_main_mv")
    sql("create materialized view coalesce_test_main_mv as " +
      "select sum(coalesce(id,0)) as sum_id,name as myname,weight from coalesce_test_main group by name,weight")
    sql("refresh materialized view coalesce_test_main_mv")

    val frame = sql("select sum(coalesce(id,0)) as sumid,name from coalesce_test_main group by name")
    assert(TestUtil.verifyMVHit(frame.queryExecution.optimizedPlan, "coalesce_test_main_mv"))
    checkAnswer(frame, Seq(Row(3, "tom"), Row(3, "lily")))

    sql("drop materialized view if exists coalesce_test_main_mv")

/**
 * Helpers shared by the materialized-view tests.
 *
 * NOTE(review): the listing stops after the `collect`; how the collected tables are compared
 * with `mvName` is not visible here.
 */
object TestUtil {
  // Collects the catalog table of every LogicalRelation in the plan. The `.get` assumes each
  // relation carries a catalog table and throws otherwise.
  def verifyMVHit(logicalPlan: LogicalPlan, mvName: String): Boolean = {
    val tables = logicalPlan collect {
      case l: LogicalRelation => l.catalogTable.get
Source File: ExistingDStream.scala    From spark-cep   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming

import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.dstream.DStream

/**
 * Leaf physical plan backed by a `DStream[InternalRow]`.
 *
 * @param output attributes produced by this plan
 * @param stream the underlying stream; marked transient
 *
 * NOTE(review): closing braces are missing from this listing. `validTime` is not defined in
 * the visible code; presumably it comes from `StreamPlan` -- confirm.
 */
case class PhysicalDStream(output: Seq[Attribute], @transient stream: DStream[InternalRow])
    extends SparkPlan with StreamPlan {

  // No child plans.
  def children = Nil

  // Requires validTime to be set, asks Utils.invoke to call "getOrCompute" on the stream with
  // that time, and falls back to an empty RDD when the call yields no result.
  override def doExecute() = {
    assert(validTime != null)
    Utils.invoke(classOf[DStream[InternalRow]], stream, "getOrCompute", (classOf[Time], validTime))
      .getOrElse(new EmptyRDD[InternalRow](sparkContext))
Example 78
Source File: RangerSparkOptimizer.scala    From spark-ranger   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * `RuleExecutor` over [[LogicalPlan]] built from the optimizer of the given session.
 *
 * NOTE(review): this listing is damaged. The expression that iterates over the batches, the
 * batch name argument and the closing braces are missing. From what remains, each batch is
 * rebuilt with a `FixedPoint` strategy and with the extended operator optimization rules
 * removed from its rule set.
 */
class RangerSparkOptimizer(spark: SparkSession) extends RuleExecutor[LogicalPlan] {

  override def batches: Seq[Batch] = {
    val optimizer = spark.sessionState.optimizer
    val extRules = optimizer.extendedOperatorOptimizationRules { batch =>
      // Set difference: keep only the rules that are not extension rules.
      val ruleSet = batch.rules.toSet -- extRules
      Batch(, FixedPoint(batch.strategy.maxIterations), ruleSet.toSeq: _*)
Example 79
Source File: DescribeDeltaHistoryCommand.scala    From delta   with Apache License 2.0 5 votes vote down vote up

// scalastyle:off import.ordering.noEmptyLine
import{DeltaErrors, DeltaLog, DeltaTableIdentifier}
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.command.RunnableCommand

/**
 * Runnable command behind DESCRIBE HISTORY for a Delta table identified either by path or
 * by table identifier.
 *
 * @param path optional filesystem path of the table; takes precedence when present
 * @param tableIdentifier optional catalog identifier, used when `path` is empty
 * @param limit optional maximum number of history entries
 * @param output output attributes, derived by default from the `CommitInfo` encoder schema
 *
 * NOTE(review): this listing ends in the middle of `run` and several closing braces are
 * missing, so the part that actually reads the history is not visible.
 */
case class DescribeDeltaHistoryCommand(
    path: Option[String],
    tableIdentifier: Option[TableIdentifier],
    limit: Option[Int],
    override val output: Seq[Attribute] = ExpressionEncoder[CommitInfo]().schema.toAttributes)
  extends RunnableCommand with DeltaLogging {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Resolve the table location: explicit path first, then the table identifier.
    val basePath =
      if (path.nonEmpty) {
        new Path(path.get)
      } else if (tableIdentifier.nonEmpty) {
        val sessionCatalog = sparkSession.sessionState.catalog
        // Lazy so the catalog is only consulted by the branches that need the metadata.
        lazy val metadata = sessionCatalog.getTableMetadata(tableIdentifier.get)

        DeltaTableIdentifier(sparkSession, tableIdentifier.get) match {
          case Some(id) if id.path.nonEmpty =>
            new Path(id.path.get)
          case Some(id) if id.table.nonEmpty =>
            new Path(metadata.location)
          case _ =>
            // Views get a dedicated error; anything else is reported as not a Delta table.
            if (metadata.tableType == CatalogTableType.VIEW) {
              throw DeltaErrors.describeViewHistory
            throw DeltaErrors.notADeltaTableException("DESCRIBE HISTORY")
      } else {
        throw DeltaErrors.missingTableIdentifierException("DESCRIBE HISTORY")

    // Max array size
    if (limit.exists(_ > Int.MaxValue - 8)) {
      throw new IllegalArgumentException("Please use a limit less than Int.MaxValue - 8.")

    val deltaLog = DeltaLog.forTable(sparkSession, basePath)
    recordDeltaOperation(deltaLog, "delta.ddl.describeHistory") {
      // A snapshot version of -1 is reported as not being a Delta table.
      if (deltaLog.snapshot.version == -1) {
        throw DeltaErrors.notADeltaTableException("DESCRIBE HISTORY")

      import sparkSession.implicits._
Example 80
Source File: PreprocessTableDelete.scala    From delta   with Apache License 2.0 5 votes vote down vote up


import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{DeltaDelete, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

/**
 * Rule that processes resolved `DeltaDelete` nodes.
 *
 * @param conf SQL configuration; not referenced by the visible code
 *
 * NOTE(review): closing braces are missing from this listing, and the result of the
 * `case d: DeltaDelete` branch (presumably a call to `toCommand`) is not visible.
 */
case class PreprocessTableDelete(conf: SQLConf) extends Rule[LogicalPlan] {

  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.resolveOperators {
      case d: DeltaDelete if d.resolved =>
        // Delete conditions that contain a subquery are rejected.
        d.condition.foreach { cond =>
          if (SubqueryExpression.hasSubquery(cond)) {
            throw DeltaErrors.subqueryNotSupportedException("DELETE", cond)

  /**
   * Converts a `DeltaDelete` into a `DeleteCommand`. Subquery aliases are eliminated from
   * the child first; the result must match `DeltaFullTable`, otherwise a
   * not-a-Delta-source error is thrown for DELETE.
   */
  def toCommand(d: DeltaDelete): DeleteCommand = EliminateSubqueryAliases(d.child) match {
    case DeltaFullTable(tahoeFileIndex) =>
      DeleteCommand(tahoeFileIndex, d.child, d.condition)

    case o =>
      throw DeltaErrors.notADeltaSourceException("DELETE", Some(o))
Example 81
Source File: DeltaUnsupportedOperationsCheck.scala    From delta   with Apache License 2.0 5 votes vote down vote up


import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, V2WriteCommand}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.RefreshTable
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

  /**
   * Verifies that the target of a V2 write command is an existing Delta table.
   *
   * Only targets matching `DeltaRelation` are checked; when `DeltaFullTable.unapply` yields
   * nothing for such a target, a not-a-Delta-table error naming `operation` is thrown.
   *
   * @param command the write command whose `table` is inspected
   * @param operation operation name passed to the error
   *
   * NOTE(review): closing braces and the body of the fallback case are missing from this
   * listing.
   */
  private def checkDeltaTableExists(command: V2WriteCommand, operation: String): Unit = {
    command.table match {
      case DeltaRelation(lr) =>
        // the extractor performs the check that we want if this is indeed being called on a Delta
        // table. It should leave others unchanged
        if (DeltaFullTable.unapply(lr).isEmpty) {
          throw DeltaErrors.notADeltaTableException(operation)
      case _ =>
Example 82
Source File: AnalysisHelper.scala    From delta   with Apache License 2.0 5 votes vote down vote up


import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

trait AnalysisHelper {
  import AnalysisHelper._

  /**
   * Attempts to resolve `expr` against the children of `planContainingExpr`.
   *
   * The expression is wrapped in a `FakeLogicalPlan` that has the same children as
   * `planContainingExpr`, and the session analyzer is executed on that plan. The expression
   * carried by the analyzed plan is returned whether or not resolution succeeded.
   *
   * Throws the error built by `DeltaErrors.analysisException` when the analyzer returns
   * something other than a `FakeLogicalPlan`.
   *
   * NOTE(review): the closing braces of this method are missing from this listing.
   */
  protected def tryResolveReferences(
      sparkSession: SparkSession)(
      expr: Expression,
      planContainingExpr: LogicalPlan): Expression = {
    val newPlan = FakeLogicalPlan(expr, planContainingExpr.children)
    sparkSession.sessionState.analyzer.execute(newPlan) match {
      case FakeLogicalPlan(resolvedExpr, _) =>
        // Return even if it did not successfully resolve
        return resolvedExpr
      case _ =>
        // This is unexpected
        throw DeltaErrors.analysisException(
          s"Could not resolve expression $expr", plan = Option(planContainingExpr))

  // Wraps the logical plan in a Dataset[Row] through Dataset.ofRows.
  // NOTE(review): the closing brace of this method is missing from this listing.
  protected def toDataset(sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[Row] = {
    Dataset.ofRows(sparkSession, logicalPlan)

  /**
   * Runs `f` and replaces a known family of "unsupported operation" failures with the error
   * built by `DeltaErrors.configureSparkSessionWithExtensionAndCatalog`.
   *
   * An exception qualifies when its message contains one of the fragments listed in
   * `possibleErrorMsgs`. The comparison is case-insensitive: both the fragments and the
   * message are lower-cased with `Locale.ROOT`, so fragments that contain capital letters
   * still match and the result does not depend on the JVM default locale. Exceptions without
   * a message, or with a message matching none of the fragments, propagate unchanged.
   *
   * @param f the action to run
   */
  protected def improveUnsupportedOpError(f: => Unit): Unit = {
    val possibleErrorMsgs = Seq(
      "is only supported with v2 table", // full error: DELETE is only supported with v2 tables
      "is not supported temporarily",    // full error: UPDATE TABLE is not supported temporarily
      "Table does not support read",
      "Table implementation does not support writes"
    ).map(_.toLowerCase(java.util.Locale.ROOT))

    def isExtensionOrCatalogError(error: Exception): Boolean = {
      // getMessage may legitimately be null; treat that as "not one of the known errors"
      // instead of failing with a NullPointerException inside the catch guard.
      Option(error.getMessage).exists { message =>
        val normalized = message.toLowerCase(java.util.Locale.ROOT)
        possibleErrorMsgs.exists(fragment => normalized.contains(fragment))
      }
    }

    try { f } catch {
      case e: Exception if isExtensionOrCatalogError(e) =>
        throw DeltaErrors.configureSparkSessionWithExtensionAndCatalog(e)
    }
  }

/**
 * Companion holding the helper plan node used to run the analyzer on a bare expression.
 *
 * NOTE(review): the closing braces are missing from this listing.
 */
object AnalysisHelper {
  // Logical plan that carries a single expression plus the given children and produces no
  // output attributes.
  case class FakeLogicalPlan(expr: Expression, children: Seq[LogicalPlan])
    extends LogicalPlan {
    override def output: Seq[Attribute] = Nil
Example 83
Source File: DruidPlannerHelper.scala    From spark-druid-olap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources.druid

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.types.{DecimalType, _}
import org.sparklinedata.druid.DruidOperatorAttribute

/**
 * Expression and attribute lookup helpers for the Druid planner.
 *
 * NOTE(review): this listing is damaged. The body of `findAttribute` and most closing
 * braces are missing.
 */
trait DruidPlannerHelper {

  /**
   * Finds the aggregate expression of `agg` that corresponds to `e` and returns it with any
   * top-level `Alias` removed.
   *
   * An aggregate expression corresponds to `e` when it equals `e`, when `e` is an
   * `AttributeReference` with the same expression id, or when it is an `Alias` whose child
   * equals `e`. Returns `None` when nothing corresponds.
   */
  def unalias(e: Expression, agg: Aggregate): Option[Expression] = {

    agg.aggregateExpressions.find { aE =>
      (aE, e) match {
        case _ if aE == e => true
        case (_, e:AttributeReference) if e.exprId == aE.exprId => true
        case (Alias(child, _), e) if child == e => true
        case _ => false
    }.map {
      case Alias(child, _) => child
      case x => x


  // NOTE(review): body not visible in this listing.
  def findAttribute(e: Expression): Option[AttributeReference] = {

  /**
   * Looks up the attribute found by `findAttribute(e)` in the output of `plan`, matching by
   * expression id, and returns `e` paired with that attribute and its output index.
   */
  def positionOfAttribute(e: Expression,
                          plan: LogicalPlan): Option[(Expression, (AttributeReference, Int))] = {
    for (aR <- findAttribute(e);
         attr <- plan.output.zipWithIndex.find(t => t._1.exprId == aR.exprId))
      yield (e, (aR, attr._2))

  /**
   * Same lookup as `positionOfAttribute`, but returns only the expression id and the output
   * index.
   */
  def exprIdToAttribute(e: Expression,
                          plan: LogicalPlan): Option[(ExprId, Int)] = {
    for (aR <- findAttribute(e);
         attr <- plan.output.zipWithIndex.find(t => t._1.exprId == aR.exprId))
      yield (aR.exprId, attr._2)

  // Bundle of grouping and aggregate expressions. Field semantics are not established by the
  // visible code; `aEToLiteralExpr` defaults to an empty map.
  case class GroupingInfo(gEs: Seq[Expression],
                          expandOpGExps : Seq[Expression],
                          aEs: Seq[NamedExpression],
                          expandOpProjection : Seq[Expression],
                          aEExprIdToPos : Map[ExprId, Int],
                          aEToLiteralExpr: Map[Expression, Expression] = Map())

  // Delegates to NumericType.acceptsType.
  def isNumericType(dt : DataType) : Boolean = NumericType.acceptsType(dt)

Example 84
Source File: HiveStrategies.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.hive.execution._

/**
 * Hive-specific planning strategies, mixed into a `SparkPlanner` through the self type.
 *
 * NOTE(review): this listing is heavily damaged. Constructor calls, parts of conditions and
 * most closing braces are missing, so the comments below cover only the visible fragments.
 */
private[hive] trait HiveStrategies {
  // Possibly being too clever with types here... or not clever enough.
  self: SparkPlanner =>

  val sparkSession: SparkSession

  // Plans logical ScriptTransformation nodes; every other plan yields Nil.
  object Scripts extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.ScriptTransformation(input, script, output, child, ioschema) =>
        val hiveIoSchema = HiveScriptIOSchema(ioschema)
          input, script, output, planLater(child, sparkContext.sparkUser), hiveIoSchema) :: Nil
      case _ => Nil

  // Plans inserts into MetastoreRelation tables and CREATE TABLE AS SELECT for tables whose
  // provider is "hive"; every other plan yields Nil.
  object DataSinks extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.InsertIntoTable(
          table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
          planLater(child, sparkContext.sparkUser),
          ifNotExists) :: Nil

      // NOTE(review): `tableDesc.provider.get` throws when the provider is empty.
      case CreateTable(tableDesc, mode, Some(query)) if tableDesc.provider.get == "hive" =>
        val newTableDesc = if ( {
          // add default serde
            serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
        } else {

        // Currently we will never hit this branch, as SQL string API can only use `Ignore` or
        // `ErrorIfExists` mode, and `DataFrameWriter.saveAsTable` doesn't support hive serde
        // tables yet.
        if (mode == SaveMode.Append || mode == SaveMode.Overwrite) {
          throw new AnalysisException(
            "CTAS for hive serde tables does not support append or overwrite semantics.")

        // Fall back to the current database of the session when the identifier has none.
        val dbName = tableDesc.identifier.database.getOrElse(sparkSession.catalog.currentDatabase)
        val cmd = CreateHiveTableAsSelectCommand(
          newTableDesc.copy(identifier = tableDesc.identifier.copy(database = Some(dbName))),
          mode == SaveMode.Ignore)
        ExecutedCommandExec(cmd) :: Nil

      case _ => Nil

  // Plans scans over MetastoreRelation tables.
  object HiveTableScans extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
        // Filter out all predicates that only deal with partition keys, these are given to the
        // hive table scan operator to be used for partition pruning.
        val partitionKeyIds = AttributeSet(relation.partitionKeys)
        val (pruningPredicates, otherPredicates) = predicates.partition { predicate =>
          !predicate.references.isEmpty &&

          HiveTableScanExec(_, relation, pruningPredicates)(sparkSession)) :: Nil
      case _ =>
Example 85
Source File: CreateHiveTableAsSelectCommand.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, OverwriteOptions}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.hive.MetastoreRelation

/**
 * Runnable command that creates a Hive table and fills it with the result of a query.
 *
 * @param tableDesc catalog description of the table to create
 * @param query logical plan producing the rows to insert
 * @param ignoreIfExists when true, an already existing table is left untouched instead of
 *                       raising an error
 *
 * NOTE(review): this listing is damaged. The storage-format defaults, the insert call, the
 * end of `argString` and most closing braces are missing.
 */
case class CreateHiveTableAsSelectCommand(
    tableDesc: CatalogTable,
    query: LogicalPlan,
    ignoreIfExists: Boolean)
  extends RunnableCommand {

  private val tableIdentifier = tableDesc.identifier

  override def innerChildren: Seq[LogicalPlan] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Lazy so the table is only created when the insert branch below is actually taken.
    lazy val metastoreRelation: MetastoreRelation = {
      import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
      import org.apache.hadoop.mapred.TextInputFormat

      val withFormat =
          inputFormat =
          outputFormat =
              .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)),
          serde =[LazySimpleSerDe].getName)),
          compressed =

      val withSchema = if (withFormat.schema.isEmpty) {
        // Hive doesn't support specifying the column list for target table in CTAS
        // However we don't think SparkSQL should follow that.
        tableDesc.copy(schema = query.output.toStructType)
      } else {

      sparkSession.sessionState.catalog.createTable(withSchema, ignoreIfExists = false)

      // Get the Metastore Relation
      sparkSession.sessionState.catalog.lookupRelation(tableIdentifier) match {
        case r: MetastoreRelation => r
    // TODO ideally, we should get the output data ready first and then
    // add the relation into catalog, just in case of failure occurs while data
    // processing.
    if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) {
      if (ignoreIfExists) {
        // table already exists, will do nothing, to keep consistent with Hive
      } else {
        throw new AnalysisException(s"$tableIdentifier already exists.")
    } else {
      try {
          metastoreRelation, Map(), query, overwrite = OverwriteOptions(true),
          ifNotExists = false)).toRdd
      } catch {
        case NonFatal(e) =>
          // drop the created table.
          sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true,
            purge = false)
          throw e


  // NOTE(review): the first fragment has a second closing brace after the interpolation, so
  // a stray brace is printed after the database name. This looks unintended; the listing
  // also ends before the final fragment of the string.
  override def argString: String = {
    s"[Database:${tableDesc.database}}, " +
    s"TableName: ${tableDesc.identifier.table}, " +
Example 86
Source File: PruneFileSourcePartitionsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

/**
 * Test suite for the PruneFileSourcePartitions optimizer rule.
 *
 * NOTE(review): this listing is damaged. The opening of the SQL statement, the predicate
 * passed to filterNot, the final assertion and the closing braces are missing. The visible
 * test builds a HadoopFsRelation over a partitioned parquet table, wraps it in a
 * LogicalRelation, and runs a Project over a Filter on the partition column through the
 * optimizer.
 */
class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {

  // Optimizer that applies only the rule under test, once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil

  test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
    withTable("test") {
      withTempDir { dir =>
            |CREATE EXTERNAL TABLE test(i int)
            |PARTITIONED BY (p int)
            |STORED AS parquet
            |LOCATION '${dir.getAbsolutePath}'""".stripMargin)

        val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
        val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)

        val dataSchema = StructType(tableMeta.schema.filterNot { f =>
        val relation = HadoopFsRelation(
          location = catalogFileIndex,
          partitionSchema = tableMeta.partitionSchema,
          dataSchema = dataSchema,
          bucketSpec = None,
          fileFormat = new ParquetFileFormat(),
          options = Map.empty)(sparkSession = spark)

        val logicalRelation = LogicalRelation(relation, catalogTable = Some(tableMeta))
        val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze

        val optimized = Optimize.execute(query)
Example 87
Source File: SQLBuilderTest.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import scala.util.control.NonFatal

import org.apache.spark.sql.{DataFrame, Dataset, QueryTest}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.hive.test.TestHiveSingleton

/**
 * Base class for tests that compare generated SQL text with an expected string.
 *
 * Three checkSQL overloads are visible: one compares the sql of an expression, one converts
 * a logical plan through SQLBuilder, compares the text and then checks that running the
 * generated SQL gives the same answer as the plan, and one delegates to the plan overload
 * with the analyzed plan of a DataFrame.
 *
 * NOTE(review): this listing is damaged. The failure-message strings are cut off after their
 * first line and the closing braces are missing.
 */
abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
  // NOTE(review): both comparison blocks catch Throwable, which also intercepts fatal JVM
  // errors; the handler bodies are not visible here.
  protected def checkSQL(e: Expression, expectedSQL: String): Unit = {
    val actualSQL = e.sql
    try {
      assert(actualSQL === expectedSQL)
    } catch {
      case cause: Throwable =>
          s"""Wrong SQL generated for the following expression:

  protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = {
    val generatedSQL = try new SQLBuilder(plan).toSQL catch { case NonFatal(e) =>
        s"""Cannot convert the following logical query plan to SQL:

    try {
      assert(generatedSQL === expectedSQL)
    } catch {
      case cause: Throwable =>
          s"""Wrong SQL generated for the following logical query plan:

    checkAnswer(spark.sql(generatedSQL), Dataset.ofRows(spark, plan))

  protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
    checkSQL(df.queryExecution.analyzed, expectedSQL)
Example 88
Source File: SubstituteUnresolvedOrdinals.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType

/**
 * Analyzer rule that rewrites integer literals used as ORDER BY or GROUP BY items.
 *
 * Sort nodes are rewritten only when `conf.orderByOrdinal` is set and at least one sort
 * expression is an integer literal; Aggregate nodes only when `conf.groupByOrdinal` is set
 * and at least one grouping expression is an integer literal. In the visible Sort branch the
 * literal is replaced by `UnresolvedOrdinal(index)` and the origins of the replaced nodes
 * are preserved through `withOrigin`.
 *
 * NOTE(review): this listing is damaged. The collections being mapped, the replacement in
 * the Aggregate branch and the closing braces are missing.
 */
class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  // True only for literals whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      // NOTE(review): presumably a map over s.order; the receiver is missing here.
      val newOrders = {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      // NOTE(review): presumably a map over a.groupingExpressions; the receiver and the
      // replacement for the literal case are missing here.
      val newGroups = {
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
Example 89
Source File: ResolveInlineTables.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}

  /**
   * Converts an `UnresolvedInlineTable` into a `LocalRelation`.
   *
   * A field is derived for every column: its type is the wider type found by
   * `TypeCoercion.findWiderTypeWithoutStringPromotion` over the column input types (analysis
   * fails when none exists), and it is nullable when any value in the column is nullable.
   * Each cell is then evaluated, through a `Cast` to the column type when its own type
   * differs; non-fatal evaluation failures are reported through `table.failAnalysis`.
   *
   * NOTE(review): this listing is damaged. The collections being iterated, the computation
   * of `inputTypes`, the same-type branch and the closing braces are missing.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // For each column, traverse all the values and find a common data type and nullability.
    val fields = { case (column, name) =>
      val inputTypes =
      val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      StructField(name, tpe, nullable = column.exists(_.nullable))
    val attributes = StructType(fields).toAttributes
    // Sanity check: exactly one derived field per declared column name.
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = { row =>
      InternalRow.fromSeq( { case (e, ci) =>
        val targetType = fields(ci).dataType
        try {
          if (e.dataType.sameType(targetType)) {
          } else {
            Cast(e, targetType).eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")

    LocalRelation(attributes, newRows)
Example 90
Source File: AnalysisException.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * Exception carrying an analysis error message together with optional position and plan
 * information. The constructor is restricted to the `sql` package.
 *
 * @param message the error message
 * @param line optional line number of the error
 * @param startPosition optional position within the line
 * @param plan optional logical plan associated with the error; transient
 * @param cause optional underlying cause, passed to `Exception` as null when absent
 *
 * NOTE(review): this listing is damaged. The receivers of the `map` calls that build the
 * annotations, the tail of `withPosition` and `getSimpleMessage`, and the closing braces are
 * missing.
 */
class AnalysisException protected[sql] (
    val message: String,
    val line: Option[Int] = None,
    val startPosition: Option[Int] = None,
    // Some plans fail to serialize due to bugs in scala collections.
    @transient val plan: Option[LogicalPlan] = None,
    val cause: Option[Throwable] = None)
  extends Exception(message, cause.orNull) with Serializable {

  // Builds a new exception with the same message and the given position. In the visible
  // code the plan and cause are not passed to the new instance.
  def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
    val newException = new AnalysisException(message, line, startPosition)

  // Simple message followed by the plan annotation, which is empty when there is no plan.
  override def getMessage: String = {
    val planAnnotation = => s";\n$p").getOrElse("")
    getSimpleMessage + planAnnotation

  // Outputs an exception without the logical plan.
  // For testing only
  def getSimpleMessage: String = {
    val lineAnnotation = => s" line $l").getOrElse("")
    val positionAnnotation = => s" pos $p").getOrElse("")
Example 91
Source File: SameResultSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.util._

/**
 * Tests for LogicalPlan.sameResult.
 *
 * The helper assertSameResult analyzes both plans and fails, printing a side by side
 * comparison, when sameResult differs from the expected value (true by default). The
 * visible tests cover relations, projections, filters, sorts and unions over two test
 * relations.
 *
 * NOTE(review): this listing is damaged. The attribute lists of both test relations, the
 * select calls in the projection test and the closing braces are missing.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation(', ', '
  val testRelation2 = LocalRelation(', ', '

  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val aAnalyzed = a.analyze
    val bAnalyzed = b.analyze

    if (aAnalyzed.sameResult(bAnalyzed) != result) {
      val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")

  test("relations") {
    assertSameResult(testRelation, testRelation2)

  test("projections") {
    assertSameResult('a, 'b),'a, 'b))
    assertSameResult('b, 'a),'b, 'a))

    assertSameResult(testRelation,'a), result = false)
    assertSameResult('b, 'a),'a, 'b), result = false)

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))

  test("union") {
    assertSameResult(Union(Seq(testRelation, testRelation2)),
      Union(Seq(testRelation2, testRelation)))
Example 92
Source File: RuleExecutorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

/**
 * Tests for the batch strategies of RuleExecutor, using a rule that decrements positive
 * integer literals by one per application.
 *
 * NOTE(review): this listing is damaged. The body of the intercept block in the last test
 * and the closing braces are missing.
 */
class RuleExecutorSuite extends SparkFunSuite {
  // Rewrites every integer literal greater than zero to the literal one less.
  object DecrementLiterals extends Rule[Expression] {
    def apply(e: Expression): Expression = e transform {
      case IntegerLiteral(i) if i > 0 => Literal(i - 1)

  // With the Once strategy a single application is expected: 10 becomes 9.
  test("only once") {
    object ApplyOnce extends RuleExecutor[Expression] {
      val batches = Batch("once", Once, DecrementLiterals) :: Nil

    assert(ApplyOnce.execute(Literal(10)) === Literal(9))

  // With up to 100 iterations the literal is expected to reach 0.
  test("to fixed point") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil

    assert(ToFixedPoint.execute(Literal(10)) === Literal(0))

  // With a limit of 10 iterations the test expects a TreeNodeException whose message
  // reports that the limit was reached.
  test("to maxIterations") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil

    val message = intercept[TreeNodeException[LogicalPlan]] {
    assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
Example 93
Source File: ConvertToLocalRelationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests the ConvertToLocalRelation optimizer rule.
 *
 * NOTE(review): several lines in this excerpt look truncated (the LocalRelation
 * arguments and the beginning of projectOnLocal); the code is left exactly as found.
 */
class ConvertToLocalRelationSuite extends PlanTest {

  // Optimizer under test: one FixedPoint(100) batch containing only ConvertToLocalRelation.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // Input relation holding the rows (1, 2) and (4, 5).
    val testRelation = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    // Expected relation holding the rows (1, 3) and (4, 6), i.e. the second value plus one.
    val correctAnswer = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    // Projection that aliases (b + 1) as b1.
    // NOTE(review): the start of this expression is missing from the excerpt.
    val projectOnLocal =
      (UnresolvedAttribute("b") + 1).as("b1"))

    // Analyze the projection, then run it through the optimizer batch.
    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)

Example 94
Source File: CollapseWindowSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests the CollapseWindow optimizer rule on stacked window operators.
 *
 * NOTE(review): closing braces and some lines appear to be missing from this
 * excerpt; the code is left exactly as found.
 */
class CollapseWindowSuite extends PlanTest {
  // Optimizer under test: one FixedPoint(10) batch containing only CollapseWindow.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil

  // Shared fixtures: a three-column relation plus two partition specs and two
  // order specs, all derived from column c, so tests can mix equal and different specs.
  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    // Four stacked windows that all use partitionSpec1 and orderSpec1.
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)
    // The optimized plan must expose the same output attributes as the analyzed plan.
    assert(analyzed.output === optimized.output)

    // NOTE(review): the expected single window looks truncated here; only the
    // avg_b expression is present in this excerpt.
    val correctAnswer = testRelation.window(Seq(
      avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Case 1: same partition spec, different order specs.
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    // The expected plan is the analyzed query itself, i.e. no rewrite.
    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    // Case 2: same order spec, different partition specs.
    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    // Again the expected plan is the analyzed query itself.
    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
Example 95
Source File: RewriteDistinctAggregatesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests the RewriteDistinctAggregates rule.
 *
 * NOTE(review): this excerpt is truncated. The definition of testRelation, most
 * aggregate lists and the final assertions of several tests are not present; the
 * code is left exactly as found.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  // Case-insensitive analyzer over an in-memory catalog with an empty function registry.
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Typed null literals.
  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  // Passes when the plan is an Aggregate over an Aggregate over an Expand; any
  // other shape fails the test with the plan included in the message.
  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  // Applies the rule and compares the result with the unchanged input.
  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  // Applies the rule and compares the result with the unchanged input.
  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  // NOTE(review): the remaining tests only build their input in this excerpt;
  // the rule application and assertions are missing.
  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Example 96
Source File: OptimizerExtendableSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

  /**
   * Optimizer subclass that appends two extra batches to those of SimpleTestOptimizer.
   *
   * NOTE(review): DummyRule and the enclosing suite are not present in this excerpt.
   */
  class ExtendedOptimizer extends SimpleTestOptimizer {

    // rules set to DummyRule, would not be executed anyways
    // One batch with the Once strategy and one with FixedPoint(100), both holding DummyRule.
    val myBatches: Seq[Batch] = {
      Batch("once", Once,
        DummyRule) ::
      Batch("fixedPoint", FixedPoint(100),
        DummyRule) :: Nil

    // The inherited batches come first, followed by the extra ones.
    override def batches: Seq[Batch] = super.batches ++ myBatches

  // Constructs an ExtendedOptimizer instance.
  // NOTE(review): any statements after the instantiation are not present in this excerpt.
  test("Extending batches possible") {
    // test simply instantiates the new extended optimizer
    val extendedOptimizer = new ExtendedOptimizer()
Source File: CollapseRepartitionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests the CollapseRepartition optimizer rule.
 *
 * NOTE(review): this excerpt is truncated. The arguments of testRelation and the
 * repartition or distribute calls that build each query are not present, so every
 * query appears as the bare testRelation; the code is left exactly as found.
 */
class CollapseRepartitionSuite extends PlanTest {
  // Optimizer under test: one FixedPoint(10) batch containing only CollapseRepartition.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseRepartition", FixedPoint(10),
        CollapseRepartition) :: Nil

  val testRelation = LocalRelation(', '

  // Expected plan: a single repartition(20).
  test("collapse two adjacent repartitions into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.repartition(20).analyze

    comparePlans(optimized, correctAnswer)

  // Expected plan: a single distribute on a with 20 partitions.
  test("collapse repartition and repartitionBy into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)

  // Expected plan: a single distribute on a with 10 partitions.
  test("collapse repartitionBy and repartition into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(10).analyze

    comparePlans(optimized, correctAnswer)

  // Expected plan: a single distribute on a with 20 partitions.
  test("collapse two adjacent repartitionBys into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
Source File: ComputeCurrentTimeSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Tests that the ComputeCurrentTime rule replaces CurrentTimestamp and CurrentDate
 * expressions with literals.
 *
 * NOTE(review): closing braces and some lines appear to be missing from this
 * excerpt; the code is left exactly as found.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  // Optimizer under test: one Once batch containing only ComputeCurrentTime.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))

  test("analyzer should replace current_timestamp with literals") {
    // Projection with two aliased CurrentTimestamp expressions.
    // NOTE(review): the child plan argument is missing from this excerpt.
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),

    // Bracket the optimizer run with wall-clock readings scaled by 1000
    // (presumably milliseconds to microseconds - confirm against the literal type).
    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    // Collect the Long value of every literal found in the optimized plan.
    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
    // Exactly two literals, both inside [min, max] and equal to each other.
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))

  test("analyzer should replace current_date with literals") {
    // Projection with two aliased CurrentDate expressions over an empty LocalRelation.
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    // Bracket the optimizer run with the current day number.
    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    // Collect the Int value of every literal found in the optimized plan.
    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
    // Exactly two literals, both inside [min, max] and equal to each other.
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
Source File: ReorderAssociativeOperatorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests the ReorderAssociativeOperator optimizer rule.
 *
 * NOTE(review): this excerpt is truncated. The arguments of testRelation and the
 * beginning of the query expressions are not present; the code is left exactly
 * as found.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // Optimizer under test: one Once batch containing only ReorderAssociativeOperator.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil

  val testRelation = LocalRelation(', ', '

  test("Reorder associative operators") {
    // Expressions that mix attributes with integer constants under + and *.
    val originalQuery =
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    // Expected form: the constants of each expression are combined into one value,
    // aliased with the original expression text. The Rand(0) expression is expected
    // to stay unchanged.
    val correctAnswer =
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)

    comparePlans(optimized, correctAnswer)

  test("nested expression with aggregate operator") {
    // Join of t1 and t2 on column a, grouped by (t1.a + 1) and (t2.a + 1).
    val originalQuery ="t1")
        .join("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    // The expected plan is the analyzed query itself, i.e. no rewrite.
    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
Source File: AggregateOptimizeSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests optimizer rules that simplify the grouping expressions of aggregates.
 *
 * NOTE(review): this excerpt is truncated. The rule list of the batch, the
 * LocalRelation arguments and the start of some query expressions are not fully
 * present; the code is left exactly as found.
 */
class AggregateOptimizeSuite extends PlanTest {
  // Case-insensitive analyzer over an in-memory catalog with an empty function registry.
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Optimizer under test: one FixedPoint(100) batch; the only rule visible in this
  // excerpt is RemoveRepetitionFromGroupExpressions.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      RemoveRepetitionFromGroupExpressions) :: Nil

  val testRelation = LocalRelation(', ', '

  // Literal grouping keys are expected to disappear, leaving only the attribute a.
  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  // When every grouping key is a literal, the expected plan groups by Literal(0).
  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)

  // A grouping key y that aliases Literal(1) is expected to be dropped as well.
  test("Remove aliased literals") {
    val query ='a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer ='a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  // Keys that repeat an earlier key with swapped operands and different letter case
  // are expected to be removed, leaving (a + 1) and (b + 2).
  test("remove repetition in grouping expression") {
    val input = LocalRelation(', ', '
    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
Source File: SparkPlanner.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.internal.SQLConf

/**
 * Planner that combines caller-supplied strategies with a fixed list of built-in ones.
 *
 * NOTE(review): this excerpt is truncated. Closing braces, the body of prunePlans
 * and parts of pruneFilterProject are not present; the code is left exactly as found.
 *
 * @param sparkContext context used to look up the current Spark user
 * @param conf SQL configuration, read for the number of shuffle partitions
 * @param extraStrategies strategies placed ahead of the built-in ones
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  // User name taken from the Spark context; passed to FileSourceStrategy below.
  def user: String = sparkContext.sparkUser
  def numPartitions: Int = conf.numShufflePartitions

  // Extra strategies come first, followed by the built-in ones in this order.
  def strategies: Seq[Strategy] =
      extraStrategies ++ (
      FileSourceStrategy(user) ::
      DataSourceStrategy ::
      DDLStrategy ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)

  // Collects every PlanLater node in the plan, paired with the logical plan it wraps.
  override protected def collectPlaceholders(
     plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
    plan.collect {
      case placeholder @ PlanLater(logicalPlan, _) => placeholder -> logicalPlan

  // NOTE(review): the body of prunePlans is not present in this excerpt.
  override protected def prunePlans(plans: Iterator[SparkPlan]): Iterator[SparkPlan] = {
    // TODO: We will need to prune bad plans when we improve plan space exploration
    //       to prevent combinatorial explosion.

  // Builds a scan through scanBuilder and wraps it with a filter and, when needed,
  // a projection.
  // NOTE(review): the filterCondition expression and parts of both branches are
  // missing from this excerpt.
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    // Attributes referenced by the projection and by the filters, respectively.
    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    val filterCondition: Option[Expression] =

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    if (AttributeSet( == projectSet &&
        filterSet.subsetOf(projectSet)) {
      // When it is possible to just use column pruning to get the right projection and
      // when the columns of this projection are enough to evaluate all filter conditions,
      // just do a scan followed by a filter, with no extra project.
      val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]), scan)).getOrElse(scan)
    } else {
      // Otherwise scan the union of projected and filtered attributes, then project.
      val scan = scanBuilder((projectSet ++ filterSet).toSeq)
      ProjectExec(projectList,, scan)).getOrElse(scan))
Source File: PruneFileSourcePartitions.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Logical-plan rule that, for a file-source relation backed by a CatalogFileIndex,
 * replaces the file index with one pruned by the partition-key filters.
 *
 * NOTE(review): this excerpt is truncated. Parts of the pattern, the filter
 * normalization, the partition column resolution, the computation of
 * partitionKeyFilters and the else branch are not present; the code is left
 * exactly as found.
 */
private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    // Only applies when there is at least one filter and a partition schema is defined.
    case op @ PhysicalOperation(projects, filters,
        logicalRelation @
          LogicalRelation(fsRelation @
              catalogFileIndex: CatalogFileIndex,
        if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
      // The attribute name of predicate could be different than the one in schema in case of
      // case insensitive, we should change them to match the one in schema, so we do not need to
      // worry about case sensitivity anymore.
      val normalizedFilters = { e =>
        e transform {
          case a: AttributeReference =>

      val sparkSession = fsRelation.sparkSession
      val partitionColumns =
          partitionSchema, sparkSession.sessionState.analyzer.resolver)
      val partitionSet = AttributeSet(partitionColumns)
      val partitionKeyFilters =

      if (partitionKeyFilters.nonEmpty) {
        // Rebuild the relation on top of a file index restricted by the partition filters.
        val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
        val prunedFsRelation =
          fsRelation.copy(location = prunedFileIndex)(sparkSession)
        // Pass the original output attributes along as the expected output of the copy.
        val prunedLogicalRelation = logicalRelation.copy(
          relation = prunedFsRelation,
          expectedOutputAttributes = Some(logicalRelation.output))

        // Keep partition-pruning predicates so that they are visible in physical planning
        val filterExpression = filters.reduceLeft(And)
        val filter = Filter(filterExpression, prunedLogicalRelation)
        Project(projects, filter)
      } else {
Source File: LogicalRelation.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.util.Utils

  // Declared to return this.type.
  // NOTE(review): the method body and the enclosing class header are not present
  // in this excerpt.
  override def newInstance(): this.type = {

  // Refreshes the file location when the relation is a HadoopFsRelation; every
  // other relation type is left untouched.
  override def refresh(): Unit = relation match {
    case fs: HadoopFsRelation => fs.location.refresh()
    case _ =>  // Do nothing.

  // One-line description: the output attributes, comma-separated through
  // Utils.truncatedString, followed by the relation itself.
  override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
Example 104
Source File: InsertIntoDataSourceCommand.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OverwriteOptions}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

/**
 * Command that inserts the result of a query into an InsertableRelation.
 *
 * NOTE(review): closing braces and the return value of run are not present in
 * this excerpt; the code is left exactly as found.
 *
 * @param logicalRelation target relation; its underlying relation is cast to InsertableRelation
 * @param query logical plan that produces the rows to insert
 * @param overwrite options whose enabled flag is forwarded to the insert call
 */
case class InsertIntoDataSourceCommand(
    logicalRelation: LogicalRelation,
    query: LogicalPlan,
    overwrite: OverwriteOptions)
  extends RunnableCommand {

  // The query is exposed as an inner child of this command.
  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // The cast throws ClassCastException when the relation is not an InsertableRelation.
    val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
    val data = Dataset.ofRows(sparkSession, query)
    // Apply the schema of the existing table to the new data.
    val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
    relation.insert(df, overwrite.enabled)

Example 105
Source File: ddl.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types._

/**
 * Command describing the creation of a table, optionally populated from a query.
 *
 * NOTE(review): the call wrapping the SaveMode check and the closing braces are
 * not present in this excerpt; the code is left exactly as found.
 *
 * @param tableDesc catalog description of the table; its provider must be defined
 * @param mode save mode to apply
 * @param query optional plan whose result populates the table
 */
case class CreateTable(
    tableDesc: CatalogTable,
    mode: SaveMode,
    query: Option[LogicalPlan]) extends Command {
  assert(tableDesc.provider.isDefined, "The table to be created must have a provider.")

  // Without a query, only the ErrorIfExists and Ignore save modes are accepted.
  if (query.isEmpty) {
      mode == SaveMode.ErrorIfExists || mode == SaveMode.Ignore,
      "create table without data insertion can only use ErrorIfExists or Ignore as SaveMode.")

  // The optional query is exposed as an inner child of this command.
  override def innerChildren: Seq[QueryPlan[_]] = query.toSeq

/**
 * Command that creates a temporary view, either global or session-local, backed
 * by a data source.
 *
 * NOTE(review): this excerpt is truncated. Parts of argString, the first DataSource
 * argument, the definition of catalog and the return value of run are not present;
 * the code is left exactly as found.
 *
 * @param tableIdent view identifier; must not specify a database
 * @param userSpecifiedSchema optional schema handed to the data source
 * @param replace forwarded to the catalog call that creates the view
 * @param global selects createGlobalTempView instead of createTempView
 * @param provider data source class name
 * @param options data source options
 */
case class CreateTempViewUsing(
    tableIdent: TableIdentifier,
    userSpecifiedSchema: Option[StructType],
    replace: Boolean,
    global: Boolean,
    provider: String,
    options: Map[String, String]) extends RunnableCommand {

  // Reject identifiers that carry a database part.
  if (tableIdent.database.isDefined) {
    throw new AnalysisException(
      s"Temporary view '$tableIdent' should not have specified a database")

  override def argString: String = {
    s"[tableIdent:$tableIdent " + + " ").getOrElse("") +
      s"replace:$replace " +
      s"provider:$provider " +

  def run(sparkSession: SparkSession): Seq[Row] = {
    val dataSource = DataSource(
      userSpecifiedSchema = userSpecifiedSchema,
      className = provider,
      options = options)

    // Resolve the data source into a relation and take its logical plan as the
    // view definition.
    val viewDefinition = Dataset.ofRows(
      sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan

    // Register the view under the table name in the global or the local namespace.
    if (global) {
      catalog.createGlobalTempView(tableIdent.table, viewDefinition, replace)
    } else {
      catalog.createTempView(tableIdent.table, viewDefinition, replace)


/**
 * Command that refreshes the metadata of the given table.
 *
 * NOTE(review): the statements of run are not present in this excerpt; only the
 * original comment describing the intent remains.
 *
 * @param tableIdent identifier of the table to refresh
 */
case class RefreshTable(tableIdent: TableIdentifier)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Refresh the given table's metadata. If this table is cached as an InMemoryRelation,
    // drop the original cached version and make the new version cached lazily.

/**
 * Command parameterized by a resource path.
 *
 * NOTE(review): the body of run is not present in this excerpt, so what is done
 * with the path cannot be confirmed from here.
 *
 * @param path resource path given to the command
 */
case class RefreshResource(path: String)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
Source File: cache.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * Command that caches a table, optionally defining it first from a query plan.
 *
 * NOTE(review): this excerpt is truncated. The body of innerChildren, the caching
 * calls and the return value of run are not present; the code is left exactly
 * as found.
 *
 * @param tableIdent table to cache; must not name a database when plan is defined
 * @param plan optional query that is registered as a temp view under the table name
 * @param isLazy when false, run enters the eager-caching branch
 */
case class CacheTableCommand(
    tableIdent: TableIdentifier,
    plan: Option[LogicalPlan],
    isLazy: Boolean) extends RunnableCommand {
  require(plan.isEmpty || tableIdent.database.isEmpty,
    "Database name is not allowed in CACHE TABLE AS SELECT")

  override protected def innerChildren: Seq[QueryPlan[_]] = {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // When a plan is given, register it as a temp view named after the table.
    plan.foreach { logicalPlan =>
      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)

    if (!isLazy) {
      // Performs eager caching


/**
 * Command that removes a table from the cache.
 *
 * NOTE(review): the statement inside the try block and the return value of run
 * are not present in this excerpt; the code is left exactly as found.
 *
 * @param tableIdent table to uncache
 * @param ifExists when true, a NoSuchTableException is swallowed instead of propagated
 */
case class UncacheTableCommand(
    tableIdent: TableIdentifier,
    ifExists: Boolean) extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val tableId = tableIdent.quotedString
    try {
    } catch {
      // Only a missing table is tolerated, and only when ifExists is set.
      case _: NoSuchTableException if ifExists => // don't throw

// Singleton runnable command.
// NOTE(review): its members are not present in this excerpt.
case object ClearCacheCommand extends RunnableCommand {

Example 107
Source File: commands.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.streaming.IncrementalExecution
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._

/**
 * Command that produces a textual explanation of a logical plan.
 *
 * NOTE(review): this excerpt is truncated. The non-streaming branch, the bodies
 * that build outputString and the successful return value are not present; the
 * code is left exactly as found.
 *
 * @param logicalPlan plan to explain
 * @param output output attributes; defaults to a single nullable string column named plan
 * @param extended selects the extended branch when codegen is false
 * @param codegen selects the codegen branch and takes precedence over extended
 */
case class ExplainCommand(
    logicalPlan: LogicalPlan,
    override val output: Seq[Attribute] =
      Seq(AttributeReference("plan", StringType, nullable = true)()),
    extended: Boolean = false,
    codegen: Boolean = false)
  extends RunnableCommand {

  // Run through the optimizer to generate the physical plan.
  override def run(sparkSession: SparkSession): Seq[Row] = try {
    val queryExecution =
      if (logicalPlan.isStreaming) {
        // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
        // output mode does not matter since there is no `Sink`.
        new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, 0)
      } else {
    val outputString =
      if (codegen) {
      } else if (extended) {
      } else {
  // A planning failure is reported as result rows, one per message line, rather
  // than being rethrown.
  } catch { case cause: TreeNodeException[_] =>
    ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
Source File: QueryExecutionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Tests how QueryExecution.toString behaves when a planning strategy throws.
 *
 * NOTE(review): closing braces and the assertions for the first two modes are
 * not present in this excerpt; the code is left exactly as found.
 */
class QueryExecutionSuite extends SharedSQLContext {
  test("toString() exception/error handling") {
    // Strategy whose behavior is switched through the mutable mode field: it throws
    // an AnalysisException, throws an Error, or plans nothing.
    val badRule = new SparkStrategy {
      var mode: String = ""
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match {
        case "exception" => throw new AnalysisException(mode)
        case "error" => throw new Error(mode)
        case _ => Nil
    spark.experimental.extraStrategies = badRule :: Nil

    // Builds a fresh QueryExecution over OneRowRelation on every call.
    def qe: QueryExecution = new QueryExecution(spark, OneRowRelation)

    // Nothing!
    badRule.mode = ""

    // Throw an AnalysisException - this should be captured.
    badRule.mode = "exception"

    // Throw an Error - this should not be captured.
    badRule.mode = "error"
    val error = intercept[Error](qe.toString)
Example 109
Source File: SparkPlannerSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Tests that the planner only follows the first candidate plan returned by a strategy.
 *
 * NOTE(review): closing braces and part of the UnionExec expression are not
 * present in this excerpt; the code is left exactly as found.
 */
class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Marker leaf node; the strategy fails the test if it is ever asked to plan it.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil

    // Counts how many times the strategy matched one of the handled plan nodes.
    var planned = 0
    object TestStrategy extends Strategy {
      def user: String = sparkContext.sparkUser
      // Every handled case bumps the counter and returns the real physical plan
      // first, followed by a planLater(NeverPlanned, user) alternative.
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child, user) :: planLater(NeverPlanned, user) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec( => planLater(p, user))) :: planLater(NeverPlanned, user) :: Nil
        case LocalRelation(output, data) =>
          planned += 1
          LocalTableScanExec(output, data, user) :: planLater(NeverPlanned, user) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      // Union of two three-element datasets.
      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // The counter is expected to reach exactly 4.
      assert(planned === 4)
    } finally {
      // Always remove the test strategy again.
      spark.experimental.extraStrategies = Nil
Source File: ExtraStrategiesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Leaf physical operator used by the extra-strategy test.
 *
 * NOTE(review): the statement that turns unsafeRow into the returned RDD is not
 * present in this excerpt; the code is left exactly as found.
 *
 * @param output output attributes of the operator
 */
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    // Build a one-column internal row from a string literal value, convert it
    // through an unsafe projection created from the schema, and copy the result.
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()

  // The operator declares its own output set as produced and has no children.
  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil

/**
 * Strategy that plans a single-expression Project as a FastOperator and declines
 * every other plan by returning Nil.
 *
 * NOTE(review): the left-hand side of the guard comparing against the string a
 * is missing from this excerpt.
 */
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil

/**
 * Tests that a strategy registered through spark.experimental.extraStrategies is used.
 *
 * NOTE(review): the beginning of both checkAnswer calls is not present in this
 * excerpt; the code is left exactly as found.
 */
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      // Two-column frame (a, b) holding one row.
      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      // Expected answer for the first check: a single row produced by the custom operator.
        Row("so fast"))

      // Expected answer for the check on both columns: the original row.
      checkAnswer("a", "b"),
        Row("so slow", 1))
    } finally {
      // Always remove the test strategy again.
      spark.experimental.extraStrategies = Nil
Source File: CreateTableAsSelect.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn}
import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes}

/**
 * Runnable command that creates a Hive table from the result of a query.
 *
 * @param tableDesc     description of the table to create
 * @param query         logical plan whose result is inserted into the table
 * @param allowExisting when true, an already existing table is left untouched instead of
 *                      raising an AnalysisException
 */
case class CreateTableAsSelect(
    tableDesc: HiveTable,
    query: LogicalPlan,
    allowExisting: Boolean)
  extends RunnableCommand {

  def database: String = tableDesc.database
  // NOTE(review): the right-hand side of tableName is missing in this excerpt.
  def tableName: String =

  override def run(sqlContext: SQLContext): Seq[Row] = {
    // The command assumes it is executed with a HiveContext; the cast fails otherwise.
    val hiveContext = sqlContext.asInstanceOf[HiveContext]
    // Lazy so that the relation is only resolved on the branch that actually inserts data.
    lazy val metastoreRelation: MetastoreRelation = {
      import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
      import org.apache.hadoop.mapred.TextInputFormat

      // NOTE(review): the expression assigned to withSchema is heavily truncated in this
      // excerpt; the visible fragments fall back to default output format and serde class
      // names when tableDesc does not provide them.
      val withSchema =
          schema =
              HiveColumn(, HiveMetastoreTypes.toMetastoreType(c.dataType), null)),
          inputFormat =
          outputFormat =
              .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)),
          serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName())))

      // Get the Metastore Relation
      hiveContext.catalog.lookupRelation(Seq(database, tableName), None) match {
        case r: MetastoreRelation => r
    // TODO ideally, we should get the output data ready first and then
    // add the relation into catalog, just in case of failure occurs while data
    // processing.
    if (hiveContext.catalog.tableExists(Seq(database, tableName))) {
      if (allowExisting) {
        // table already exists, will do nothing, to keep consistent with Hive
      } else {
        throw new AnalysisException(s"$database.$tableName already exists.")
    } else {
      // Table does not exist yet: execute an InsertIntoTable plan over the query result.
      hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd


  // Argument string combining database, table name and the query plan text.
  override def argString: String = {
    s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString
Example 112
Source File: SqlParserSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.Command

/**
 * Minimal logical command used by the parser tests in this file; it carries the parsed
 * command text and has no output attributes and no children.
 */
private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
  override def output: Seq[Attribute] = Seq.empty
  override def children: Seq[LogicalPlan] = Seq.empty

/**
 * Test parser whose start rule accepts the EXECUTE keyword followed by an identifier and
 * wraps that identifier in a TestCommand.
 */
// NOTE(review): the declaration of the EXECUTE keyword is not visible in this class in
// this excerpt; confirm its definition against the full file.
private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {

  override protected lazy val start: Parser[LogicalPlan] = set

  // Drop the keyword and keep only the identifier that follows it.
  private lazy val set: Parser[LogicalPlan] =
    EXECUTE ~> ident ^^ {
      case fileName => TestCommand(fileName)

/**
 * Test parser that declares an EXECUTE keyword and parses it followed by an identifier
 * into a TestCommand.
 */
private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
  protected val EXECUTE = Keyword("EXECUTE")

  override protected lazy val start: Parser[LogicalPlan] = set

  // Drop the keyword and keep only the identifier that follows it.
  private lazy val set: Parser[LogicalPlan] =
    EXECUTE ~> ident ^^ {
      case fileName => TestCommand(fileName)

/**
 * Tests for the two test parsers defined in this file: one for a very long keyword and
 * one for keyword matching regardless of letter case.
 */
class SqlParserSuite extends SparkFunSuite {

  test("test long keyword") {
    val parser = new SuperLongKeywordTestParser
    assert(TestCommand("NotRealCommand") ===
      parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))

  test("test case insensitive") {
    val parser = new CaseInsensitiveTestParser
    // The same command is expected for upper, lower and mixed case spellings of the keyword.
    assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand"))
    assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand"))
    assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand"))
Source File: SameResultSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._

/**
 * Tests for LogicalPlan.sameResult on relations, projections, filters and sorts.
 */
// NOTE(review): the two relation definitions and the projection assertions are
// truncated in this excerpt (column definitions and select calls are missing).
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation(', ', '
  val testRelation2 = LocalRelation(', ', '

  // Analyzes both plans and fails with a side-by-side rendering of the two plans when
  // sameResult does not equal the expected value.
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val aAnalyzed = a.analyze
    val bAnalyzed = b.analyze

    if (aAnalyzed.sameResult(bAnalyzed) != result) {
      val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")

  test("relations") {
    assertSameResult(testRelation, testRelation2)

  test("projections") {
    assertSameResult('a, 'b),'a, 'b))
    assertSameResult('b, 'a),'b, 'a))

    assertSameResult(testRelation,'a), result = false)
    assertSameResult('b, 'a),'a, 'b), result = false)

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
Source File: ConvertToLocalRelationSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan test for the ConvertToLocalRelation optimizer rule.
 */
class ConvertToLocalRelationSuite extends PlanTest {

  // Optimizer under test: a single batch that applies ConvertToLocalRelation with a
  // fixed-point limit of 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // NOTE(review): the schema arguments of both relations and the start of the projection
    // are truncated in this excerpt; only the row data is fully visible.
    val testRelation = LocalRelation(
      LocalRelation(', ',
      Row(1, 2) ::
      Row(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation(', ',
      Row(1, 3) ::
      Row(4, 6) :: Nil)

    val projectOnLocal =
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)

Example 115
Source File: OptimizeInSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

/**
 * Plan tests for the OptimizeIn rule, which is applied after EliminateSubQueries.
 */
class OptimizeInSuite extends PlanTest {

  // Optimizer under test: two batches, each run once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubQueries) ::
      Batch("ConstantFolding", Once,
        OptimizeIn) :: Nil

  // NOTE(review): the column definitions of testRelation and the receivers of the where
  // calls below are truncated in this excerpt.
  val testRelation = LocalRelation(', ', '

  test("OptimizedIn test: In clause optimized to InSet") {
    val originalQuery =
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))

    comparePlans(optimized, correctAnswer)

  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
    val originalQuery =
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))

    val optimized = Optimize.execute(originalQuery.analyze)
    // The expected plan keeps the In expression because its list contains an attribute.
    val correctAnswer =
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))

    comparePlans(optimized, correctAnswer)
Source File: ProjectCollapsingSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan tests for the ProjectCollapsing rule, run after EliminateSubQueries.
 */
// NOTE(review): the relation columns and parts of the expected plans are truncated in
// this excerpt.
class ProjectCollapsingSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
        Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil

  val testRelation = LocalRelation(', '

  test("collapse two deterministic, independent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer ='a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, correctAnswer)

  test("collapse two deterministic, dependent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      (('a + 1).as('a_plus_1) + 1).as('a_plus_2),

    comparePlans(optimized, correctAnswer)

  test("do not collapse nondeterministic projects") {
    val query = testRelation
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    // The expected plan is the analyzed query itself, so the optimizer must leave it as is.
    val correctAnswer = query.analyze

    comparePlans(optimized, correctAnswer)
Source File: SparkSQLParser.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.types.StringType

/**
 * Parser for a small set of commands (CACHE, UNCACHE, CLEAR CACHE, SET, SHOW TABLES).
 * The start rule tries cache, uncache, set and show in that order and then the others rule.
 *
 * @param fallback function from a SQL string to a logical plan
 */
// NOTE(review): the places where fallback is applied are not visible in this excerpt;
// several rule bodies below are truncated.
private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {

  // A parser for the key-value part of the "SET [key = [value ]]" syntax
  private object SetCommandParser extends RegexParsers {
    // Key: one or more characters other than the equals sign.
    private val key: Parser[String] = "(?m)[^=]+".r

    // Value: everything up to the end of the line.
    private val value: Parser[String] = "(?m).*$".r

    private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)())

    // Both the whole pair and the value part are optional.
    private val pair: Parser[LogicalPlan] =
      (key ~ ("=".r ~> value).?).? ^^ {
        case None => SetCommand(None, output)
        case Some(k ~ v) => SetCommand(Some(k.trim ->, output)

    // Parses the complete input; any result other than Success raises an error that
    // carries the parser result text.
    def apply(input: String): LogicalPlan = parseAll(pair, input) match {
      case Success(plan, _) => plan
      case x => sys.error(x.toString)

  protected val AS = Keyword("AS")
  protected val CACHE = Keyword("CACHE")
  protected val CLEAR = Keyword("CLEAR")
  protected val IN = Keyword("IN")
  protected val LAZY = Keyword("LAZY")
  protected val SET = Keyword("SET")
  protected val SHOW = Keyword("SHOW")
  protected val TABLE = Keyword("TABLE")
  protected val TABLES = Keyword("TABLES")
  protected val UNCACHE = Keyword("UNCACHE")

  override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others

  // CACHE [LAZY] TABLE name [AS rest-of-input]
  private lazy val cache: Parser[LogicalPlan] =
    CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
      case isLazy ~ tableName ~ plan =>
        CacheTableCommand(tableName,, isLazy.isDefined)

  // UNCACHE TABLE name, or CLEAR CACHE
  private lazy val uncache: Parser[LogicalPlan] =
    ( UNCACHE ~ TABLE ~> ident ^^ {
        case tableName => UncacheTableCommand(tableName)
    | CLEAR ~ CACHE ^^^ ClearCacheCommand

  // SET followed by the remaining input, which is handed to SetCommandParser.
  private lazy val set: Parser[LogicalPlan] =
    SET ~> restInput ^^ {
      case input => SetCommandParser(input)

  // SHOW TABLES [IN database]
  private lazy val show: Parser[LogicalPlan] =
    SHOW ~> TABLES ~ (IN ~> ident).? ^^ {
      case _ ~ dbName => ShowTablesCommand(dbName)

  // Last alternative of the start rule; consumes the whole input.
  private lazy val others: Parser[LogicalPlan] =
    wholeInput ^^ {
Example 118
Source File: ExistingRDD.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SQLContext}

/**
 * Leaf logical plan that holds a sequence of rows together with its output attributes.
 *
 * @param output     attributes produced by this plan
 * @param rows       the rows held by this plan
 * @param sqlContext context whose configuration is read by the statistics estimate
 */
case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext)
   extends LogicalPlan with MultiInstanceRelation {

  override def children: Seq[LogicalPlan] = Nil

  // Creates a new LogicalLocalTable over the same rows and context.
  // NOTE(review): the first constructor argument is missing in this excerpt; confirm how
  // the output attributes of the new instance are derived.
  override def newInstance(): this.type =
    LogicalLocalTable(, rows)(sqlContext).asInstanceOf[this.type]

  /**
   * Reports whether the given plan is known to produce the same rows as this local table.
   *
   * Only another LogicalLocalTable that holds an equal row sequence qualifies. The previous
   * implementation matched LogicalRDD and evaluated rows == rows, which compares the field
   * with itself and therefore reported every LogicalRDD as equivalent.
   *
   * @param plan the plan to compare against
   * @return true only when plan is a LogicalLocalTable whose rows equal the rows of this table
   */
  override def sameResult(plan: LogicalPlan): Boolean = plan match {
    // The extractor exposes the first parameter list only: (output, rows).
    case LogicalLocalTable(_, otherRows) => rows == otherRows
    case _ => false
  }

  // Size estimate fixed at one byte below the configured auto-broadcast join threshold.
  @transient override lazy val statistics: Statistics = Statistics(
    // TODO: Improve the statistics estimation.
    // This is made small enough so it can be broadcasted.
    sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1
Source File: FramelessInternals.scala    From frameless   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ObjectType
import scala.reflect.ClassTag

/**
 * Helpers declared inside the org.apache.spark.sql package that forward to Dataset,
 * Column and plan members.
 */
object FramelessInternals {
  // Builds an ObjectType from the runtime class of the implicit ClassTag.
  def objectTypeFor[A](implicit classTag: ClassTag[A]): ObjectType = ObjectType(classTag.runtimeClass)

  // Resolves the given name parts against the analyzed plan of the dataset using the
  // session analyzer resolver; throws AnalysisException when nothing resolves.
  def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = {
    ds.toDF.queryExecution.analyzed.resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver).getOrElse {
      throw new AnalysisException(
        s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames.mkString(", ")})""")

  // Returns the expression wrapped by the column.
  def expr(column: Column): Expression = column.expr

  // Same result as expr above: returns the expression wrapped by the column.
  def column(column: Column): Expression = column.expr

  def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan

  // NOTE(review): the body of executePlan is missing in this excerpt.
  def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution =

  // Executes the plan, then takes the leading attributes (as many as leftPlan outputs) and
  // the trailing attributes (as many as rightPlan outputs) of the analyzed output and
  // wraps each group in a struct aliased _1 and _2 respectively.
  // NOTE(review): the constructor call that receives the two aliases is truncated here.
  def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan, rightPlan: LogicalPlan): LogicalPlan = {
    val joined = executePlan(ds, plan)
    val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
    val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)

      Alias(CreateStruct(leftOutput), "_1")(),
      Alias(CreateStruct(rightOutput), "_2")()
    ), joined.analyzed)

  // Constructs a Dataset directly from a context, a plan and an encoder.
  def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] =
    new Dataset(sqlContext, plan, encoder)

  // Forwards to Dataset.ofRows.
  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
    Dataset.ofRows(sparkSession, logicalPlan)

  // because org.apache.spark.sql.types.UserDefinedType is private[spark]
  type UserDefinedType[A >: Null] =  org.apache.spark.sql.types.UserDefinedType[A]

  // Expression wrapper that delegates evaluation and data type to the tagged expression
  // and always reports itself as non-nullable. Code generation is not implemented.
  case class DisambiguateRight[T](tagged: Expression) extends Expression with NonSQLExpression {
    def eval(input: InternalRow): Any = tagged.eval(input)
    def nullable: Boolean = false
    def children: Seq[Expression] = tagged :: Nil
    def dataType: DataType = tagged.dataType
    protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ???
Example 120
Source File: AuthzHelper.scala    From kyuubi   with Apache License 2.0 5 votes vote down vote up

import org.apache.kyuubi.Logging
import org.apache.spark.KyuubiConf._
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

import yaooqinn.kyuubi.utils.ReflectUtils

/**
 * Looks up an authorization rule named in the Spark configuration.
 *
 * @param conf configuration that is read for the AUTHORIZATION_METHOD key
 */
private[kyuubi] class AuthzHelper(conf: SparkConf) extends Logging {

  // Reads the configured module name and loads it reflectively; any result that is not
  // a Rule yields Nil. A missing configuration key is logged as an error.
  // NOTE(review): the result expressions of the Rule branch and of the catch branch are
  // not visible in this excerpt.
  def rule: Seq[Rule[LogicalPlan]] = {
    try {
      val authzMethod = conf.get(AUTHORIZATION_METHOD.key)
      val maybeRule = ReflectUtils.reflectModule(authzMethod, silent = true)
      maybeRule match {
        case Some(authz) if authz.isInstanceOf[Rule[_]] =>
        case _ => Nil
    } catch {
      case _: NoSuchElementException =>
        error(s"${AUTHORIZATION_METHOD.key} is not configured")

/**
 * Holder for an optional AuthzHelper instance that is created by init.
 */
private[kyuubi] object AuthzHelper extends Logging {

  // Stays None until init is called with authorization enabled.
  private[this] var instance: Option[AuthzHelper] = None

  def get: Option[AuthzHelper] = instance

  // Creates the helper only when the AUTHORIZATION_ENABLE setting parses to true.
  def init(conf: SparkConf): Unit = {
    if (conf.get(AUTHORIZATION_ENABLE.key).toBoolean) {
      instance = Some(new AuthzHelper(conf))
      debug("AuthzHelper inited.")
Source File: SparkSQLUtils.scala    From kyuubi   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.types.DataType

/**
 * Utility methods declared inside the org.apache.spark.sql package.
 */
// NOTE(review): every method body except toDataFrame is missing in this excerpt, so
// only the signatures are documented here.
object SparkSQLUtils {

  // Takes a pair of value and data type and returns a string.
  def toHiveString(a: (Any, DataType)): String = {

  // Returns a class loader for the given session.
  def getUserJarClassLoader(sparkSession: SparkSession): ClassLoader = {

  // Turns a statement string into a logical plan for the given session.
  def parsePlan(sparkSession: SparkSession, statement: String): LogicalPlan = {

  // Wraps a logical plan in a DataFrame through Dataset.ofRows.
  def toDataFrame(sparkSession: SparkSession, plan: LogicalPlan): DataFrame = {
    Dataset.ofRows(sparkSession, plan)

  def initializeMetaStoreClient(sparkSession: SparkSession): Seq[String] = {
Example 122
Source File: AuthzHelperSuite.scala    From kyuubi   with Apache License 2.0 5 votes vote down vote up

import org.apache.spark.{KyuubiConf, KyuubiSparkUtil, SparkConf, SparkFunSuite}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Tests for AuthzHelper construction under different AUTHORIZATION_METHOD settings and
 * for AuthzHelper.init.
 */
// NOTE(review): the assertions of both tests are not visible in this excerpt; only the
// setup statements remain.
class AuthzHelperSuite extends SparkFunSuite {

  test("test Rule") {

    // NoSuchElementException
    val conf = new SparkConf(loadDefaults = true)
    val authzHelper1 = new AuthzHelper(conf)
    // reflect failure
    val authzHelper2 = new AuthzHelper(conf)

    // success
    conf.set(KyuubiConf.AUTHORIZATION_METHOD.key, "yaooqinn.kyuubi.TestRule")
    val authzHelper3 = new AuthzHelper(conf)

    // type mismatch
    conf.set(KyuubiConf.AUTHORIZATION_METHOD.key, "yaooqinn.kyuubi.TestWrongRule")
    val authzHelper4 = new AuthzHelper(conf)

  test("test Init") {
    // Start with authorization disabled, then switch it on.
    val conf = new SparkConf(loadDefaults = true)
      .set(KyuubiConf.AUTHORIZATION_METHOD.key, "yaooqinn.kyuubi.TestRule")
      .set(KyuubiConf.AUTHORIZATION_ENABLE.key, "false")

    conf.set(KyuubiConf.AUTHORIZATION_ENABLE.key, "true")
Source File: AddSourceToAttributes.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine.rule

import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.MetadataBuilder
import tech.sourced.engine.{GitRelation, MetadataRelation, Sources}
import tech.sourced.engine.compat

  // Rewrites the plan bottom-up: every logical relation backed by a GitRelation or a
  // MetadataRelation is passed through withMetadata together with its schema source.
  // NOTE(review): the pattern component that binds out is not visible in this excerpt.
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case compat.LogicalRelation(rel @ GitRelation(_, _, _, schemaSource),
                                catalogTable) =>
      withMetadata(rel, schemaSource, out, catalogTable)

    case compat.LogicalRelation(
        rel @ MetadataRelation(_, _, _, _, schemaSource),
        catalogTable) =>
      withMetadata(rel, schemaSource, out, catalogTable)

  // Rebuilds a logical relation for the given base relation. When a schema source is
  // present, metadata carrying that source name under the SOURCE key is built for the
  // output; without a schema source the output attributes are used unchanged.
  // NOTE(review): the receiver of the lambda in the Some branch is truncated in this
  // excerpt; presumably it maps over out - confirm against the full file.
  private def withMetadata(relation: BaseRelation,
                           schemaSource: Option[String],
                           out: Seq[AttributeReference],
                           catalogTable: Option[CatalogTable]): LogicalRelation = {
    val processedOut = schemaSource match {
      case Some(table) =>
        _.withMetadata(new MetadataBuilder().putString(SOURCE, table).build()
      case None => out

    compat.LogicalRelation(relation, processedOut, catalogTable)

Source File: HiveStrategies.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand}
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _}
import org.apache.spark.sql.hive.execution._

/**
 * Planner strategies for Hive: script transformation, data sinks, table scans, table
 * creation through data sources, and DESCRIBE handling. The self type requires the trait
 * to be mixed into a SparkPlanner.
 */
// NOTE(review): several object headers and constructor names are missing in this
// excerpt; the comments below describe only what the visible lines show.
private[hive] trait HiveStrategies {
  // Possibly being too clever with types here... or not clever enough.
  self: SQLContext#SparkPlanner =>

  val hiveContext: HiveContext

  // Plans a logical ScriptTransformation with a Hive IO schema as the physical operator.
  object Scripts extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) =>
        ScriptTransformation(input, script, output, planLater(child), schema)(hiveContext) :: Nil
      case _ => Nil

  // Plans inserts into a MetastoreRelation, for both the generic and the Hive-specific
  // logical insert nodes.
  object DataSinks extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.InsertIntoTable(
          table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
          table, partition, planLater(child), overwrite, ifNotExists) :: Nil
      case hive.InsertIntoHiveTable(
          table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
          table, partition, planLater(child), overwrite, ifNotExists) :: Nil
      case _ => Nil

  object HiveTableScans extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
        // Filter out all predicates that only deal with partition keys, these are given to the
        // hive table scan operator to be used for partition pruning.
        val partitionKeyIds = AttributeSet(relation.partitionKeys)
        val (pruningPredicates, otherPredicates) = predicates.partition { predicate =>
          !predicate.references.isEmpty &&

          HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
      case _ =>

    // Plans the two data source table creation nodes as executed commands. Both patterns
    // match only when the third positional flag of the node is false.
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case CreateTableUsing(
        tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) =>
        val cmd =
            tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)
        ExecutedCommand(cmd) :: Nil

      case CreateTableUsingAsSelect(
        tableIdent, provider, false, partitionCols, mode, opts, query) =>
        val cmd =
          CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query)
        ExecutedCommand(cmd) :: Nil

      case _ => Nil

    // Plans DESCRIBE: the table is analyzed first, then a MetastoreRelation is described
    // with DescribeHiveTableCommand while any other plan is executed and described from
    // its physical plan.
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case describe: DescribeCommand =>
        val resolvedTable = context.executePlan(describe.table).analyzed
        resolvedTable match {
          case t: MetastoreRelation =>
              DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil

          case o: LogicalPlan =>
            val resultPlan = context.executePlan(o).executedPlan
              resultPlan, describe.output, describe.isExtended)) :: Nil

      case _ => Nil
Source File: SqlParserSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.Command

/**
 * Minimal logical command used by the parser tests in this file; it carries the parsed
 * command text and has no output attributes and no children.
 */
private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
  override def output: Seq[Attribute] = Seq.empty
  override def children: Seq[LogicalPlan] = Seq.empty

/**
 * Test parser whose start rule accepts the EXECUTE keyword followed by an identifier and
 * wraps that identifier in a TestCommand.
 */
// NOTE(review): the declaration of the EXECUTE keyword is not visible in this class in
// this excerpt; confirm its definition against the full file.
private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {

  override protected lazy val start: Parser[LogicalPlan] = set

  // Drop the keyword and keep only the identifier that follows it.
  private lazy val set: Parser[LogicalPlan] =
    EXECUTE ~> ident ^^ {
      case fileName => TestCommand(fileName)

/**
 * Test parser that declares an EXECUTE keyword and parses it followed by an identifier
 * into a TestCommand.
 */
private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
  protected val EXECUTE = Keyword("EXECUTE")

  override protected lazy val start: Parser[LogicalPlan] = set

  // Drop the keyword and keep only the identifier that follows it.
  private lazy val set: Parser[LogicalPlan] =
    EXECUTE ~> ident ^^ {
      case fileName => TestCommand(fileName)

/**
 * Tests for the two test parsers defined in this file: one for a very long keyword and
 * one for keyword matching regardless of letter case.
 */
class SqlParserSuite extends SparkFunSuite {

  test("test long keyword") {
    val parser = new SuperLongKeywordTestParser
    assert(TestCommand("NotRealCommand") ===
      parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))

  test("test case insensitive") {
    val parser = new CaseInsensitiveTestParser
    // The same command is expected for upper, lower and mixed case spellings of the keyword.
    assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand"))
    assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand"))
    assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand"))
Source File: SameResultSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._

/**
 * Tests for LogicalPlan.sameResult on relations, projections, filters and sorts.
 */
// NOTE(review): the two relation definitions and the projection assertions are
// truncated in this excerpt (column definitions and select calls are missing).
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation(', ', '
  val testRelation2 = LocalRelation(', ', '

  // Analyzes both plans and fails with a side-by-side rendering of the two plans when
  // sameResult does not equal the expected value.
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val aAnalyzed = a.analyze
    val bAnalyzed = b.analyze

    if (aAnalyzed.sameResult(bAnalyzed) != result) {
      val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")

  test("relations") {
    assertSameResult(testRelation, testRelation2)

  test("projections") {
    assertSameResult('a, 'b),'a, 'b))
    assertSameResult('b, 'a),'b, 'a))

    assertSameResult(testRelation,'a), result = false)
    assertSameResult('b, 'a),'a, 'b), result = false)

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
Source File: ConvertToLocalRelationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan test for the ConvertToLocalRelation optimizer rule.
 */
class ConvertToLocalRelationSuite extends PlanTest {

  // Optimizer under test: a single batch that applies ConvertToLocalRelation with a
  // fixed-point limit of 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // NOTE(review): the schema arguments of both relations and the start of the projection
    // are truncated in this excerpt; only the row data is fully visible.
    val testRelation = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal =
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)

Source File: ColumnPruningSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType

/**
 * Plan tests for the ColumnPruning rule applied to Generate nodes.
 */
// NOTE(review): several relation definitions and the innermost child arguments of the
// plans below are truncated in this excerpt.
class ColumnPruningSuite extends PlanTest {

  // Optimizer under test: a single batch that applies ColumnPruning with a fixed-point
  // limit of 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Column pruning", FixedPoint(100),
      ColumnPruning) :: Nil

  test("Column pruning for Generate when Generate.join = false") {
    val input = LocalRelation(', 'b.array(StringType))

    val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
    val optimized = Optimize.execute(query)

    // The expected plan inserts a Project on column b beneath the Generate node.
    val correctAnswer =
      Generate(Explode('b), false, false, None, 's.string :: Nil,
        Project('b.attr :: Nil, input)).analyze

    comparePlans(optimized, correctAnswer)
  // Column pruning for Generate when Generate.join = true
  test("Column pruning for Generate when Generate.join = true") {
    val input = LocalRelation(', ', 'c.array(StringType))

    val query =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil,
    val optimized = Optimize.execute(query)

    val correctAnswer =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil,
          Project(Seq('a, 'c),

    comparePlans(optimized, correctAnswer)
  test("Turn Generate.join to false if possible") {
    val input = LocalRelation('b.array(StringType))

    val query =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), true, false, None, 's.string :: Nil,
    val optimized = Optimize.execute(query)

    // The expected plan differs from the query in the second Generate argument (false).
    val correctAnswer =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), false, false, None, 's.string :: Nil,

    comparePlans(optimized, correctAnswer)

Example 129
Source File: OptimizeInSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

/**
 * Plan-level tests for the `OptimizeIn` rule, driven by a small local rule executor.
 *
 * NOTE(review): this listing looks truncated by extraction -- closing braces are missing,
 * `testRelation` is cut off, and the `.where(...)` chains have lost their receiver.
 * Restore those from the upstream suite before compiling.
 */
class OptimizeInSuite extends PlanTest {

  // Executor under test: one pass of EliminateSubQueries, then one pass of OptimizeIn.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubQueries) ::
      Batch("ConstantFolding", Once,
        OptimizeIn) :: Nil

  // Test relation (argument list truncated in this listing).
  val testRelation = LocalRelation(', ', '
  // Two-literal IN list: the optimized plan is compared with the original query itself.
  test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
    val originalQuery =
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))

    val optimized = Optimize.execute(originalQuery.analyze)
    comparePlans(optimized, originalQuery)
  // Eleven-literal IN list (1 to 11): the expected plan uses InSet over the same values.
  test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
    val originalQuery =
        .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
        .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))

    comparePlans(optimized, correctAnswer)
  // IN list mixing literals with attribute "b": the expected plan keeps the same In.
  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
    val originalQuery =
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))

    comparePlans(optimized, correctAnswer)
Source File: ProjectCollapsingSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan-level tests for the `ProjectCollapsing` rule.
 *
 * NOTE(review): this listing looks truncated by extraction -- closing braces are missing,
 * several `correctAnswer` expressions are cut off, and one test has lost its header.
 * Restore those from the upstream suite before compiling.
 */
class ProjectCollapsingSuite extends PlanTest {
  // Executor under test: EliminateSubQueries up to 10 iterations, then ProjectCollapsing once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
        Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil

  // Relation used as the base of every query below (argument list truncated in this listing).
  val testRelation = LocalRelation(', '
  // Two stacked selects where the second only adds an expression over 'b.
  test("collapse two deterministic, independent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer ='a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, correctAnswer)

  // The second select builds on the alias ('a_plus_1) introduced by the first.
  test("collapse two deterministic, dependent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      (('a + 1).as('a_plus_1) + 1).as('a_plus_2),

    comparePlans(optimized, correctAnswer)

  // The expected plan is simply the analyzed query, i.e. the optimizer must leave it as is.
  test("do not collapse nondeterministic projects") {
    val query = testRelation
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = query.analyze

    comparePlans(optimized, correctAnswer)

    // NOTE(review): the statements below redeclare `query`, `optimized` and
    // `correctAnswer`; they look like the body of a separate test whose header and
    // select chains were lost in extraction.
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation

    comparePlans(optimized, correctAnswer)

  // First select carries Rand(10); the expected plan is a single select of ('a + 1).
  test("collapse one nondeterministic, one deterministic, independent projects into one") {
    val query = testRelation
      .select(Rand(10).as('rand), 'a)
      .select(('a + 1).as('a_plus_1))

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation
      .select(('a + 1).as('a_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
Source File: AggregateOptimizeSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan-level tests for aggregate-related optimizer rules.
 *
 * NOTE(review): this listing looks truncated by extraction -- closing braces are missing
 * and the LocalRelation argument lists are cut off.
 */
class AggregateOptimizeSuite extends PlanTest {

  // Executor under test: the "Aggregate" batch runs to a fixed point (max 100 iterations).
  // NOTE(review): only RemoveLiteralFromGroupExpressions is listed here, yet the first
  // test expects Distinct to be rewritten into Aggregate -- a rule may be missing from
  // this listing; confirm against the upstream suite.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      RemoveLiteralFromGroupExpressions) :: Nil
  // Distinct(input) is expected to become an Aggregate that groups by, and outputs,
  // every column of the input.
  test("replace distinct with aggregate") {
    val input = LocalRelation(', '

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)
  // Grouping by 'a plus two literal expressions is expected to reduce to grouping by 'a.
  test("remove literals in grouping expression") {
    val input = LocalRelation(', '

    val query =
      input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(query)

    val correctAnswer = input.groupBy('a)(sum('b))

    comparePlans(optimized, correctAnswer)
Source File: SparkSQLParser.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.types.StringType

/**
 * Parser for a handful of Spark-specific commands (CACHE / UNCACHE / CLEAR CACHE, SET,
 * SHOW TABLES / SHOW FUNCTIONS, DESCRIBE FUNCTION). Any other input is handed to
 * `fallback` unchanged.
 *
 * NOTE(review): this listing looks truncated by extraction -- closing braces/parentheses
 * are missing and a few expressions (e.g. the SetCommand pair and the CacheTableCommand
 * arguments) are cut off. Restore them from the upstream file before compiling.
 *
 * @param fallback parser invoked with the whole input when none of the commands match
 */
private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {

  // A parser for the key-value part of the "SET [key = [value ]]" syntax.
  private object SetCommandParser extends RegexParsers {
    // One or more characters that are not '='.
    private val key: Parser[String] = "(?m)[^=]+".r

    // Everything up to the end of the line (multi-line mode).
    private val value: Parser[String] = "(?m).*$".r

    private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)())

    // Optional key, optionally followed by "=" and a value; no key at all yields SetCommand(None).
    private val pair: Parser[LogicalPlan] =
      (key ~ ("=".r ~> value).?).? ^^ {
        case None => SetCommand(None)
        case Some(k ~ v) => SetCommand(Some(k.trim ->

    // Parses the complete input; any non-Success result is raised through sys.error.
    def apply(input: String): LogicalPlan = parseAll(pair, input) match {
      case Success(plan, _) => plan
      case x => sys.error(x.toString)

  // Keywords recognised by this parser.
  protected val AS = Keyword("AS")
  protected val CACHE = Keyword("CACHE")
  protected val CLEAR = Keyword("CLEAR")
  protected val DESCRIBE = Keyword("DESCRIBE")
  protected val EXTENDED = Keyword("EXTENDED")
  protected val FUNCTION = Keyword("FUNCTION")
  protected val FUNCTIONS = Keyword("FUNCTIONS")
  protected val IN = Keyword("IN")
  protected val LAZY = Keyword("LAZY")
  protected val SET = Keyword("SET")
  protected val SHOW = Keyword("SHOW")
  protected val TABLE = Keyword("TABLE")
  protected val TABLES = Keyword("TABLES")
  protected val UNCACHE = Keyword("UNCACHE")

  // Entry point: alternatives are tried in order, with `others` (the fallback) last.
  override protected lazy val start: Parser[LogicalPlan] =
    cache | uncache | set | show | desc | others

  // CACHE [LAZY] TABLE <name> [AS <rest of input>]
  private lazy val cache: Parser[LogicalPlan] =
    CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
      case isLazy ~ tableName ~ plan =>
        CacheTableCommand(tableName,, isLazy.isDefined)

  // UNCACHE TABLE <name>, or CLEAR CACHE
  private lazy val uncache: Parser[LogicalPlan] =
    ( UNCACHE ~ TABLE ~> ident ^^ {
        case tableName => UncacheTableCommand(tableName)
    | CLEAR ~ CACHE ^^^ ClearCacheCommand

  // SET <rest of input>; the remainder is delegated to SetCommandParser.
  private lazy val set: Parser[LogicalPlan] =
    SET ~> restInput ^^ {
      case input => SetCommandParser(input)

  // It can be the following patterns:
  // SHOW FUNCTIONS mydb.func1;
  // SHOW FUNCTIONS func1;
  // SHOW FUNCTIONS `mydb.a`.`func1.aa`;
  private lazy val show: Parser[LogicalPlan] =
    ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ {
        case _ ~ dbName => ShowTablesCommand(dbName)
    | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ {
        case Some(f) => ShowFunctions(f._1, Some(f._2))
        case None => ShowFunctions(None, None)

  // DESCRIBE FUNCTION [EXTENDED] <name or string literal>
  private lazy val desc: Parser[LogicalPlan] =
    DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ {
      case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined)

  // Anything else: pass the whole input to the fallback parser.
  private lazy val others: Parser[LogicalPlan] =
    wholeInput ^^ {
      case input => fallback(input)

Example 133
Source File: InsertIntoDataSource.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

/**
 * Command that inserts the result of `query` into the relation wrapped by
 * `logicalRelation`.
 *
 * NOTE(review): this listing looks truncated by extraction -- the cache invalidation
 * statement, the return value and the closing braces are missing.
 *
 * @param logicalRelation target; its relation is cast to InsertableRelation, so any
 *                        other relation type fails with a ClassCastException
 * @param query           plan producing the rows to insert
 * @param overwrite       forwarded unchanged to `InsertableRelation.insert`
 */
private[sql] case class InsertIntoDataSource(
    logicalRelation: LogicalRelation,
    query: LogicalPlan,
    overwrite: Boolean)
  extends RunnableCommand {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
    val data = DataFrame(sqlContext, query)
    // Apply the schema of the existing table to the new data.
    val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
    relation.insert(df, overwrite)

    // Invalidate the cache.

Source File: ExtraStrategiesSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String
/**
 * Leaf physical operator used by the test strategy in this file; it builds a single row
 * holding the literal "so fast".
 *
 * NOTE(review): this listing looks truncated by extraction -- the expression that turns
 * `row` into an RDD and the closing braces are missing.
 */
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    // Literal(...).value yields the literal's internal value rather than the Scala String.
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
  // No child operators.
  override def children: Seq[SparkPlan] = Nil
/**
 * Planning strategy that maps a single-column Project to a FastOperator and returns Nil
 * for every other plan.
 *
 * NOTE(review): the guard `if == "a"` has lost its left-hand side in extraction
 * (presumably a check on the projected attribute's name -- confirm upstream).
 */
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
/**
 * Checks that a strategy registered through `sqlContext.experimental.extraStrategies`
 * takes part in planning.
 *
 * NOTE(review): this listing looks truncated by extraction -- both `checkAnswer` calls
 * have lost their DataFrame argument and the closing braces are missing.
 */
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {// Insert an extra strategy
    try {
      // Register the strategy for the duration of this test only (reset in `finally`).
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
        Row("so fast"))

      checkAnswer("a", "b"),
        Row("so slow", 1))
    } finally {
      // Always restore the default so other suites are unaffected.
      sqlContext.experimental.extraStrategies = Nil
Source File: HBaseSparkSession.scala    From Heracles   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hbase

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.catalog.ExternalCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.hbase.execution.{HBaseSourceAnalysis, HBaseStrategies}
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SQLConf, SessionState, SharedState}

/**
 * SparkSession subclass that swaps in HBase-specific session and shared state.
 *
 * NOTE(review): this listing looks garbled by extraction -- the auxiliary constructor is
 * cut off after `this(`, and the dangling `sc.hadoopConfiguration, ...` line further down
 * has lost the statement it belongs to. Closing braces are missing as well.
 */
class HBaseSparkSession(sc: SparkContext) extends SparkSession(sc) {
  self =>

  def this(sparkContext: JavaSparkContext) = this(

  // Session state is built lazily by the HBase-specific builder.
  override lazy val sessionState: SessionState = new HBaseSessionStateBuilder(this).build()

    sc.hadoopConfiguration, HBaseConfiguration.create(sc.hadoopConfiguration))

  // Shared state whose external catalog is HBase-backed (see HBaseSharedState).
  override lazy val sharedState: SharedState =
    new HBaseSharedState(sc, this.sqlContext)

/**
 * Session state builder that installs an HBase SQL configuration, an HBase planning
 * strategy and extra analyzer rules.
 *
 * NOTE(review): `parentState` is accepted but not forwarded to BaseSessionStateBuilder
 * in the header below -- confirm whether that is intentional.
 * NOTE(review): this listing looks truncated by extraction -- the rule lists end with a
 * dangling `+:`, `experimentalMethods` never returns `result`, and closing braces are
 * missing.
 */
class HBaseSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session) {
  override lazy val conf: SQLConf = new HBaseSQLConf

  override protected def newBuilder: NewBuilder = new HBaseSessionStateBuilder(_, _)

  // Registers the HBaseDataSource strategy, obtained from a throwaway SparkPlanner that
  // mixes in HBaseStrategies.
  override lazy val experimentalMethods: ExperimentalMethods = {
    val result = new ExperimentalMethods;
    result.extraStrategies = Seq((new SparkPlanner(session.sparkContext, conf, new ExperimentalMethods)
      with HBaseStrategies).HBaseDataSource)

  // Analyzer with the data-source resolution rules plus HBaseSourceAnalysis.
  override lazy val analyzer: Analyzer = {
    new Analyzer(catalog, conf) {
      override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
          new FindDataSourceTable(session) +:
          new ResolveSQLOnFile(session) +:

      override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
          PreprocessTableCreation(session) +:
          PreprocessTableInsertion(conf) +:
          DataSourceAnalysis(conf) +:
          HBaseSourceAnalysis(session) +:

      override val extendedCheckRules =

/**
 * Shared state whose external catalog is an HBaseCatalog built from the given SQLContext
 * and the SparkContext's Hadoop configuration.
 * NOTE(review): the closing brace is missing from this listing.
 */
class HBaseSharedState(sc: SparkContext, sqlContext: SQLContext) extends SharedState(sc) {
  override lazy val externalCatalog: ExternalCatalog =
    new HBaseCatalog(sqlContext, sc.hadoopConfiguration)
Source File: InsertIntoHiveDirCommand.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import scala.language.existentials

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.FileUtils
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.mapred._

import org.apache.spark.SparkException
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.hive.client.HiveClientImpl

/**
 * Command that writes the output of a physical plan into the directory given by
 * `storage.locationUri`, going through a temporary path first and then renaming the
 * produced files into the destination.
 *
 * NOTE(review): this listing looks truncated by extraction -- several statements (the
 * TableDesc arguments, the local/qualified path results, the save call that takes the
 * named arguments, the `finally` body) and most closing braces are missing.
 *
 * @param isLocal       selects the local file system branch when resolving the target
 * @param storage       storage format; `locationUri` must be defined (`.get` is called)
 * @param query         logical plan whose schema is used for the Hive table description
 * @param overwrite     when true, existing files under the target are deleted first
 * @param outputColumns passed on as `allColumns` to the write
 */
case class InsertIntoHiveDirCommand(
    isLocal: Boolean,
    storage: CatalogStorageFormat,
    query: LogicalPlan,
    overwrite: Boolean,
    outputColumns: Seq[Attribute]) extends SaveAsHiveFile {

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {

    // Build a Hive table description whose identifier is derived from the target location.
    val hiveTable = HiveClientImpl.toHiveTable(CatalogTable(
      identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")),
      tableType = org.apache.spark.sql.catalyst.catalog.CatalogTableType.VIEW,
      storage = storage,
      schema = query.schema

    val tableDesc = new TableDesc(

    val hadoopConf = sparkSession.sessionState.newHadoopConf()
    val jobConf = new JobConf(hadoopConf)

    // Resolve the destination against either the local or the job's file system.
    val targetPath = new Path(storage.locationUri.get)
    val writeToPath =
      if (isLocal) {
        val localFileSystem = FileSystem.getLocal(jobConf)
      } else {
        val qualifiedPath = FileUtils.makeQualified(targetPath, hadoopConf)
        val dfs = qualifiedPath.getFileSystem(jobConf)
        if (!dfs.exists(qualifiedPath)) {

    // Data is written under a temporary path and moved into place afterwards.
    val tmpPath = getExternalTmpPath(sparkSession, hadoopConf, writeToPath)
    val fileSinkConf = new org.apache.spark.sql.hive.HiveShim.ShimFileSinkDesc(
      tmpPath.toString, tableDesc, false)

    try {
        sparkSession = sparkSession,
        plan = child,
        hadoopConf = hadoopConf,
        fileSinkConf = fileSinkConf,
        outputLocation = tmpPath.toString,
        allColumns = outputColumns)

      val fs = writeToPath.getFileSystem(hadoopConf)
      // On overwrite, clear the destination but keep the temp dir created for this write.
      if (overwrite && fs.exists(writeToPath)) {
        fs.listStatus(writeToPath).foreach { existFile =>
          if (Option(existFile.getPath) != createdTempDir) fs.delete(existFile.getPath, true)

      // Move every produced file from the temporary path into the destination.
      fs.listStatus(tmpPath).foreach {
        tmpFile => fs.rename(tmpFile.getPath, writeToPath)
    } catch {
      // NOTE(review): this wraps every Throwable, fatal errors included; consider
      // narrowing to scala.util.control.NonFatal.
      case e: Throwable =>
        throw new SparkException(
          "Failed inserting overwrite directory " + storage.locationUri.get, e)
    } finally {

Source File: CreateHiveTableAsSelectCommand.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand

/**
 * CREATE TABLE AS SELECT for Hive tables: inserts into the table when it already exists
 * (subject to `mode`), otherwise creates it first and drops it again if the write fails.
 *
 * NOTE(review): this listing looks truncated by extraction -- the insert command
 * constructions that take the named arguments, the `partition` expression, the final
 * result, the tail of `argString` and most closing braces are missing.
 *
 * @param tableDesc     description of the table to create or insert into
 * @param query         logical plan; its schema becomes the schema of a newly created table
 * @param outputColumns forwarded to the insert command
 * @param mode          save mode consulted when the table already exists
 */
case class CreateHiveTableAsSelectCommand(
    tableDesc: CatalogTable,
    query: LogicalPlan,
    outputColumns: Seq[Attribute],
    mode: SaveMode)
  extends DataWritingCommand {

  private val tableIdentifier = tableDesc.identifier

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog
    if (catalog.tableExists(tableIdentifier)) {
      // Overwrite is not expected here: the assertion message states the table should
      // already have been dropped in that mode.
      assert(mode != SaveMode.Overwrite,
        s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite")

      if (mode == SaveMode.ErrorIfExists) {
        throw new AnalysisException(s"$tableIdentifier already exists.")
      if (mode == SaveMode.Ignore) {
        // Since the table already exists and the save mode is Ignore, we will just return.
        return Seq.empty

        overwrite = false,
        ifPartitionNotExists = false,
        outputColumns = outputColumns).run(sparkSession, child)
    } else {
      // TODO ideally, we should get the output data ready first and then
      // add the relation into catalog, just in case of failure occurs while data
      // processing.
      catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false)

      try {
        // Read back the metadata of the table which was created just now.
        val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier)
        // For CTAS, there is no static partition values to insert.
        val partition = -> None).toMap
          overwrite = true,
          ifPartitionNotExists = false,
          outputColumns = outputColumns).run(sparkSession, child)
      } catch {
        case NonFatal(e) =>
          // drop the created table.
          catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false)
          throw e


  // NOTE(review): in the first fragment the `}` after `${tableDesc.database}` is literal
  // text, so the rendered string contains a stray closing brace -- likely a typo.
  override def argString: String = {
    s"[Database:${tableDesc.database}}, " +
    s"TableName: ${tableDesc.identifier.table}, " +
Source File: HiveSessionStateBuilder.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}

  // Planner for this session: a SparkPlanner with HiveStrategies mixed in, whose extra
  // strategies are the inherited ones, then the custom ones, then HiveTableScans and Scripts.
  // NOTE(review): closing braces are missing from this listing.
  override protected def planner: SparkPlanner = {
    new SparkPlanner(session.sparkContext, conf, experimentalMethods) with HiveStrategies {
      override val sparkSession: SparkSession = session

      override def extraPlanningStrategies: Seq[Strategy] =
        super.extraPlanningStrategies ++ customPlanningStrategies ++ Seq(HiveTableScans, Scripts)

  // Factory for further builders of this kind: a two-argument function that constructs
  // a HiveSessionStateBuilder.
  override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _)

/**
 * Session resource loader that also holds a HiveClient.
 *
 * NOTE(review): the body of `addJar` and the closing braces are missing from this
 * listing, so what is done with `client` cannot be confirmed here.
 */
class HiveSessionResourceLoader(
    session: SparkSession,
    client: HiveClient)
  extends SessionResourceLoader(session) {
  override def addJar(path: String): Unit = {
Source File: PruneFileSourcePartitionsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

/**
 * Tests for the `PruneFileSourcePartitions` rule against Hive-backed tables.
 *
 * NOTE(review): this listing looks truncated by extraction -- the opening of the SQL
 * string, the body of the `filterNot` predicate, the final assertion of the first test
 * and most closing braces are missing.
 */
class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {

  // Executor under test: a single pass of PruneFileSourcePartitions.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil

  // Builds a HadoopFsRelation over an external partitioned parquet table by hand and
  // runs a filter on the partition column `p` through the rule.
  test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
    withTable("test") {
      withTempDir { dir =>
            |CREATE EXTERNAL TABLE test(i int)
            |PARTITIONED BY (p int)
            |STORED AS parquet
            |LOCATION '${dir.toURI}'""".stripMargin)

        val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
        val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)

        val dataSchema = StructType(tableMeta.schema.filterNot { f =>
        val relation = HadoopFsRelation(
          location = catalogFileIndex,
          partitionSchema = tableMeta.partitionSchema,
          dataSchema = dataSchema,
          bucketSpec = None,
          fileFormat = new ParquetFileFormat(),
          options = Map.empty)(sparkSession = spark)

        val logicalRelation = LogicalRelation(relation, tableMeta)
        val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze

        val optimized = Optimize.execute(query)

  // After pruning to p = 1, the relation's size estimate must match its catalog table
  // stats and be smaller than the size recorded for the whole table.
  test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") {
    withTable("tbl") {
      spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
      val tableStats = spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats
      assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost")

      val df = sql("SELECT * FROM tbl WHERE p = 1")
      // In the analyzed plan the relation still reports the whole-table size.
      val sizes1 = df.queryExecution.analyzed.collect {
        case relation: LogicalRelation => relation.catalogTable.get.stats.get.sizeInBytes
      assert(sizes1.size === 1, s"Size wrong for:\n ${df.queryExecution}")
      assert(sizes1(0) == tableStats.get.sizeInBytes)

      // In the optimized plan the size must have shrunk.
      val relations = df.queryExecution.optimizedPlan.collect {
        case relation: LogicalRelation => relation
      assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
      val size2 = relations(0).stats.sizeInBytes
      assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes)
      assert(size2 < tableStats.get.sizeInBytes)
Source File: SubstituteUnresolvedOrdinals.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

/**
 * Analyzer rule that rewrites integer literals used as ORDER BY / GROUP BY positions
 * into UnresolvedOrdinal nodes, keeping the origin of the replaced nodes.
 *
 * NOTE(review): this listing looks truncated by extraction -- the `map` receivers of
 * `newOrders` / `newGroups`, the body of the first GROUP BY case and the closing braces
 * are missing.
 *
 * @param conf supplies the `orderByOrdinal` and `groupByOrdinal` switches checked below
 */
class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  // True only for a Literal whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    // Sort nodes: only touched when the feature is enabled and at least one sort key is
    // an integer literal.
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    // Aggregate nodes: same idea for the grouping expressions.
    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = {
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
Source File: ResolveInlineTables.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}

  // Converts an unresolved inline table into a LocalRelation: derives one StructField per
  // column (common type, nullability), then evaluates every cell after casting it to the
  // column type when the types differ.
  // Analysis fails via table.failAnalysis when a column has incompatible types or a cell
  // cannot be evaluated.
  // NOTE(review): this listing looks truncated by extraction -- the collections being
  // mapped over, the `inputTypes` expression, the evaluation of `castedExpr` and the
  // closing braces are missing.
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // For each column, traverse all the values and find a common data type and nullability.
    val fields = { case (column, name) =>
      val inputTypes =
      val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      // A column is nullable as soon as any of its values is.
      StructField(name, tpe, nullable = column.exists(_.nullable))
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = { row =>
      InternalRow.fromSeq( { case (e, ci) =>
        val targetType = fields(ci).dataType
        try {
          // Cast only when the expression's type differs from the column type.
          val castedExpr = if (e.dataType.sameType(targetType)) {
          } else {
            cast(e, targetType)
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex)

    LocalRelation(attributes, newRows)
Source File: AnalysisException.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * Exception raised for analysis failures, optionally carrying the source position and
 * the logical plan involved.
 *
 * NOTE(review): this listing looks truncated by extraction -- `withPosition` never
 * returns `newException`, the `Option(...).map` / `line.map` style receivers are cut off,
 * the tail of `getSimpleMessage` is missing, and so are the closing braces.
 *
 * @param message       error text; also passed to the Exception superclass
 * @param line          optional line number
 * @param startPosition optional start position
 * @param plan          optional plan; marked @transient (see the inline comment)
 * @param cause         optional underlying cause, passed to Exception as `cause.orNull`
 */
class AnalysisException protected[sql] (
    val message: String,
    val line: Option[Int] = None,
    val startPosition: Option[Int] = None,
    // Some plans fail to serialize due to bugs in scala collections.
    @transient val plan: Option[LogicalPlan] = None,
    val cause: Option[Throwable] = None)
  extends Exception(message, cause.orNull) with Serializable {

  // Builds a new exception with the same message and the given position; `plan` and
  // `cause` are not passed to the constructor here.
  def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
    val newException = new AnalysisException(message, line, startPosition)

  // Simple message followed by the plan annotation; `getOrElse("")` makes the annotation
  // empty when there is nothing to show.
  override def getMessage: String = {
    val planAnnotation = Option(plan) => s";\n$p").getOrElse("")
    getSimpleMessage + planAnnotation

  // Outputs an exception without the logical plan.
  // For testing only
  def getSimpleMessage: String = {
    val lineAnnotation = => s" line $l").getOrElse("")
    val positionAnnotation = => s" pos $p").getOrElse("")
Source File: StatsEstimationTestBase.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Base trait for statistics-estimation tests: enables CBO for the whole suite and
 * restores the previous setting afterwards.
 *
 * NOTE(review): closing braces (and presumably the `super.beforeAll()` /
 * `super.afterAll()` calls -- confirm upstream) are missing from this listing.
 */
trait StatsEstimationTestBase extends SparkFunSuite {

  // Value of CBO_ENABLED before the suite changed it; restored in afterAll.
  var originalValue: Boolean = false

  override def beforeAll(): Unit = {
    // Enable stats estimation based on CBO.
    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)

  override def afterAll(): Unit = {
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)

  // Per-value size of a column: the average length, plus a fixed 8 + 4 for strings.
  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
    // For UTF8String: base + offset + numBytes
    case StringType => colStat.avgLen + 8 + 4
    case _ => colStat.avgLen

  // Shorthand for an IntegerType attribute reference with the given name.
  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()

/**
 * Leaf plan with hand-supplied statistics, for use in estimation tests.
 *
 * @param outputList     attributes reported as this plan's output
 * @param rowCount       row count reported in the statistics
 * @param attributeStats per-attribute column statistics reported as-is
 * @param size           optional size in bytes; Int.MaxValue is used when absent
 */
case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(): Statistics = Statistics(
    // If sizeInBytes is useless in testing, we just use a fake value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
    attributeStats = attributeStats)
Source File: SameResultSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union}
import org.apache.spark.sql.catalyst.util._

/**
 * Tests for `sameResult` on logical plans.
 *
 * NOTE(review): this listing looks truncated by extraction -- the LocalRelation argument
 * lists are cut off, the `select` receivers in the "projections" test are missing, and
 * closing braces are gone.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation(', ', '
  val testRelation2 = LocalRelation(', ', '

  // Analyzes both plans and fails with a side-by-side dump when `sameResult` disagrees
  // with the expected `result` (true by default).
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val aAnalyzed = a.analyze
    val bAnalyzed = b.analyze

    if (aAnalyzed.sameResult(bAnalyzed) != result) {
      val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")

  test("relations") {
    assertSameResult(testRelation, testRelation2)

  // The last two assertions expect a mismatch (result = false).
  test("projections") {
    assertSameResult('a, 'b),'a, 'b))
    assertSameResult('b, 'a),'b, 'a))

    assertSameResult(testRelation,'a), result = false)
    assertSameResult('b, 'a),'a, 'b), result = false)

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))

  // The two unions list the same relations in opposite order.
  test("union") {
    assertSameResult(Union(Seq(testRelation, testRelation2)),
      Union(Seq(testRelation2, testRelation)))

  // A join whose right side is wrapped in ResolvedHint must match the plain join.
  test("hint") {
    val df1 = testRelation.join(ResolvedHint(testRelation))
    val df2 = testRelation.join(testRelation)
    assertSameResult(df1, df2)
Source File: RuleExecutorSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

/**
 * Tests for RuleExecutor's batch strategies (Once, FixedPoint) and its structural
 * integrity check, using a rule that decrements integer literals.
 *
 * NOTE(review): this listing looks truncated by extraction -- the bodies of both
 * `intercept` blocks (and the `.getMessage` that would make `message` a String) and the
 * closing braces are missing.
 */
class RuleExecutorSuite extends SparkFunSuite {
  // Replaces every positive integer literal i with the literal i - 1.
  object DecrementLiterals extends Rule[Expression] {
    def apply(e: Expression): Expression = e transform {
      case IntegerLiteral(i) if i > 0 => Literal(i - 1)

  // A single application turns 10 into 9.
  test("only once") {
    object ApplyOnce extends RuleExecutor[Expression] {
      val batches = Batch("once", Once, DecrementLiterals) :: Nil

    assert(ApplyOnce.execute(Literal(10)) === Literal(9))

  // With up to 100 iterations the literal 10 is driven all the way down to 0.
  test("to fixed point") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil

    assert(ToFixedPoint.execute(Literal(10)) === Literal(0))

  // With a limit of 10 iterations the executor is expected to report hitting the limit.
  test("to maxIterations") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil

    val message = intercept[TreeNodeException[LogicalPlan]] {
    assert(message.contains("Max iterations (10) reached for batch fixedPoint"))

  // The overridden check accepts only integer literals as a valid plan.
  test("structural integrity checker") {
    object WithSIChecker extends RuleExecutor[Expression] {
      override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
        case IntegerLiteral(_) => true
        case _ => false
      val batches = Batch("once", Once, DecrementLiterals) :: Nil

    assert(WithSIChecker.execute(Literal(10)) === Literal(9))

    val message = intercept[TreeNodeException[LogicalPlan]] {
    assert(message.contains("the structural integrity of the plan is broken"))
Source File: ResolvedUuidExpressionsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}

/**
 * Tests around `Uuid` expressions in analyzed plans: the assertions below inspect the
 * `randomSeed` of the `Uuid` expressions collected from plans produced by the analyzer.
 *
 * NOTE(review): this listing appears truncated by extraction (the initializer of `a`,
 * the left-hand part of each `val plan =` expression and all closing braces are missing).
 * Confirm against the upstream file before relying on it.
 */
class ResolvedUuidExpressionsSuite extends AnalysisTest {

  private lazy val a = '
  private lazy val r = LocalRelation(a)
  private lazy val uuid1 = Uuid().as('_uuid1)
  private lazy val uuid2 = Uuid().as('_uuid2)
  private lazy val uuid3 = Uuid().as('_uuid3)
  private lazy val uuid1Ref = uuid1.toAttribute

  // Analyzer obtained from the AnalysisTest helper, configured as case sensitive.
  private val analyzer = getAnalyzer(caseSensitive = true)

  // Collects every Uuid expression reachable from the expressions of every node in the plan.
  private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = {
    plan.flatMap {
      case p =>
        p.expressions.flatMap(_.collect {
          case u: Uuid => u

  test("analyzed plan sets random seed for Uuid expression") {
    val plan =, uuid1)
    val resolvedPlan = analyzer.executeAndCheck(plan)
    getUuidExpressions(resolvedPlan).foreach { u =>

  test("Uuid expressions should have different random seeds") {
    val plan =, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
    val resolvedPlan = analyzer.executeAndCheck(plan)
    // Three Uuid expressions in the plan are expected to carry three distinct seeds.
    assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3)

  test("Different analyzed plans should have different random seeds in Uuids") {
    val plan =, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
    // The same unresolved plan is analyzed twice; the two results must not share any Uuid.
    val resolvedPlan1 = analyzer.executeAndCheck(plan)
    val resolvedPlan2 = analyzer.executeAndCheck(plan)
    val uuids1 = getUuidExpressions(resolvedPlan1)
    val uuids2 = getUuidExpressions(resolvedPlan2)
    assert(uuids1.distinct.length == 3)
    assert(uuids2.distinct.length == 3)
    assert(uuids1.intersect(uuids2).length == 0)
Source File: ConvertToLocalRelationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Runs the ConvertToLocalRelation rule in a single FixedPoint(100) batch and compares the
 * optimized plan with an expected LocalRelation.
 *
 * NOTE(review): the listing is truncated by extraction (relation schemas, the receiver of
 * the projection and all closing braces are missing); confirm against the upstream file.
 */
class ConvertToLocalRelationSuite extends PlanTest {

  // Rule executor containing only the rule under test.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // Input rows: (1, 2) and (4, 5).
    val testRelation = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    // Expected rows: (1, 3) and (4, 6), i.e. the second value of each input row plus one.
    val correctAnswer = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal =
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)

Source File: PullupCorrelatedPredicatesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Runs the PullupCorrelatedPredicates rule once over a query whose filter contains an
 * IN predicate with a correlated list subquery.
 *
 * NOTE(review): the listing is truncated by extraction (relation schemas, the receivers of
 * the `.where` calls, the final assertion and all closing braces are missing); confirm
 * against the upstream file.
 */
class PullupCorrelatedPredicatesSuite extends PlanTest {

  // Rule executor containing only the rule under test, applied once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("PullupCorrelatedPredicates", Once,
        PullupCorrelatedPredicates) :: Nil

  val testRelation = LocalRelation(', 'b.double)
  val testRelation2 = LocalRelation(', 'd.double)

  test("PullupCorrelatedPredicates should not produce unresolved plan") {
    // The subquery filter compares column b with column d.
    val correlatedSubquery =
        .where('b < 'd)
    val outerQuery =
        .where(In('a, Seq(ListQuery(correlatedSubquery))))

    val optimized = Optimize.execute(outerQuery)
Example 149
Source File: CheckCartesianProductsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.scalatest.Matchers._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf.CROSS_JOINS_ENABLED

/**
 * Tests the CheckCartesianProducts rule under both settings of the CROSS_JOINS_ENABLED
 * configuration, for join types listed as requiring or not requiring a join condition.
 *
 * NOTE(review): the listing is truncated by extraction (relation schemas, several test
 * bodies and all closing braces are missing); confirm against the upstream file.
 */
class CheckCartesianProductsSuite extends PlanTest {

  // Rule executor containing only the rule under test, applied once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: Nil

  val testRelation1 = LocalRelation(', '
  val testRelation2 = LocalRelation(', '

  // The two groups of join types the tests iterate over.
  val joinTypesWithRequiredCondition = Seq(Inner, LeftOuter, RightOuter, FullOuter)
  val joinTypesWithoutRequiredCondition = Seq(LeftSemi, LeftAnti, ExistenceJoin('exists))

  test("CheckCartesianProducts doesn't throw an exception if cross joins are enabled)") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
      noException should be thrownBy {
        for (joinType <- joinTypesWithRequiredCondition ++ joinTypesWithoutRequiredCondition) {

  test("CheckCartesianProducts throws an exception for join types that require a join condition") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithRequiredCondition) {
        val thrownException = the [AnalysisException] thrownBy {
        assert(thrownException.message.contains("Detected implicit cartesian product"))

  test("CheckCartesianProducts doesn't throw an exception if a join condition is present") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithRequiredCondition) {
        noException should be thrownBy {
          performCartesianProductCheck(joinType, Some('a === 'd))

  test("CheckCartesianProducts doesn't throw an exception if join types don't require conditions") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithoutRequiredCondition) {
        noException should be thrownBy {

  // Joins testRelation1 with testRelation2 using the given join type and optional condition,
  // analyzes the join, runs Optimize on it and asserts through comparePlans that the
  // optimized plan matches the analyzed one.
  private def performCartesianProductCheck(
      joinType: JoinType,
      condition: Option[Expression] = None): Unit = {
    val analyzedPlan = testRelation1.join(testRelation2, joinType, condition).analyze
    val optimizedPlan = Optimize.execute(analyzedPlan)
    comparePlans(analyzedPlan, optimizedPlan)
Source File: EliminateDistinctSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Runs the EliminateDistinct rule once and compares the optimized query with an expected
 * answer, for the Max and Min cases named in the test titles.
 *
 * NOTE(review): the listing is truncated by extraction. As shown, `query` and `answer`
 * are both just `testRelation`, so `assert(query != answer)` cannot hold; the aggregate
 * calls that distinguished them (and all closing braces) are missing. Confirm against the
 * upstream file.
 */
class EliminateDistinctSuite extends PlanTest {

  // Rule executor containing only the rule under test, applied once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Operator Optimizations", Once,
        EliminateDistinct) :: Nil

  val testRelation = LocalRelation('

  test("Eliminate Distinct in Max") {
    val query = testRelation
    val answer = testRelation
    // Guards against a vacuous test: the input must differ from the expected output.
    assert(query != answer)
    comparePlans(Optimize.execute(query), answer)

  test("Eliminate Distinct in Min") {
    val query = testRelation
    val answer = testRelation
    assert(query != answer)
    comparePlans(Optimize.execute(query), answer)
Example 151
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

/**
 * Checks that running an optimizer whose first batch contains a rule that injects an
 * unresolved attribute fails with a structural-integrity error message.
 *
 * NOTE(review): the listing is truncated by extraction (the body of the `intercept` block,
 * possibly a SessionCatalog constructor argument, and all closing braces are missing);
 * confirm against the upstream file.
 */
class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {

  // Rule that appends an UnresolvedAttribute named unresolvedAttr to the project list of
  // every Project node it visits.
  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Project(projectList, child) =>
        val newAttr = UnresolvedAttribute("unresolvedAttr")
        Project(projectList ++ Seq(newAttr), child)

  // Optimizer whose batches are the breaking batch followed by the inherited batches.
  object Optimize extends Optimizer(
    new SessionCatalog(
      new InMemoryCatalog,
      new SQLConf())) {
    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
    override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches

  test("check for invalid plan after execution of rule") {
    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
    val message = intercept[TreeNodeException[LogicalPlan]] {
    val ruleName = OptimizeRuleBreakSI.ruleName
    // The message is expected to name both the offending rule and its batch.
    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
    assert(message.contains("the structural integrity of the plan is broken"))
Source File: CollapseWindowSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the CollapseWindow optimizer rule: adjacent window operators are expected to
 * be merged only when they share partition and order specs and do not depend on each other.
 *
 * NOTE(review): closing braces and three entries of the expected window-expression list in
 * the first test were absent from this listing and have been restored from the query
 * defined in that test; confirm against the upstream Spark file.
 */
class CollapseWindowSuite extends PlanTest {
  /** Rule executor containing only the rule under test. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  // Two partition specs and two order specs that differ from each other.
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)
    // Collapsing must not change the output attributes of the plan.
    assert(analyzed.output === optimized.output)

    // All four window expressions end up in a single window operator.
    val correctAnswer = testRelation.window(Seq(
      min(a).as('min_a),
      max(a).as('max_a),
      sum(b).as('sum_b),
      avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Same partition spec, different order spec.
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    // Same order spec, different partition spec.
    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
  }

  test("Don't collapse adjacent windows with dependent columns") {
    // The outer window reads sum_a, which the inner window produces.
    val query = testRelation
      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1)
      .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1)

    val expected = query.analyze
    val optimized = Optimize.execute(query.analyze)
    comparePlans(optimized, expected)
  }
}
Source File: RewriteDistinctAggregatesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for RewriteDistinctAggregates, using an analyzer built on an in-memory catalog
 * with case sensitivity and group-by-ordinal both switched off.
 *
 * NOTE(review): the listing is truncated by extraction (the relation schema, several
 * groupBy calls, the checkRewrite invocations and all closing braces are missing);
 * confirm against the upstream file.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Typed null literals.
  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, '

  // Passes when the plan is an Aggregate over an Aggregate over an Expand; any other shape
  // fails the test with a message that includes the plan.
  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    // The rule is expected to leave this input unchanged.
    comparePlans(input, rewrite)

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Example 154
Source File: OptimizerExtendableSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

  /**
   * Optimizer subclass that appends two extra batches, one `Once` and one
   * `FixedPoint(100)`, after the batches inherited from SimpleTestOptimizer.
   *
   * NOTE(review): the enclosing suite header, the definition of DummyRule and the closing
   * braces are not visible in this listing; confirm against the upstream file.
   */
  class ExtendedOptimizer extends SimpleTestOptimizer {

    // rules set to DummyRule, would not be executed anyways
    val myBatches: Seq[Batch] = {
      Batch("once", Once,
        DummyRule) ::
      Batch("fixedPoint", FixedPoint(100),
        DummyRule) :: Nil

    // Inherited batches first, then the extra ones.
    override def batches: Seq[Batch] = super.batches ++ myBatches

  // Smoke test: constructing the extended optimizer is the only action visible here.
  // NOTE(review): any assertion and the closing brace are missing from this listing.
  test("Extending batches possible") {
    // test simply instantiates the new extended optimizer
    val extendedOptimizer = new ExtendedOptimizer()
Example 155
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

/**
 * Tests for the EliminateMapObjects rule: deserializing a primitive array column is
 * expected to become a direct `Invoke` of the matching `to*Array` method.
 *
 * NOTE(review): the closing braces were absent from this listing and have been restored;
 * no other token was changed.
 */
class EliminateMapObjectsSuite extends PlanTest {
  /** Rule executor containing only the rule under test. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = {
      Batch("EliminateMapObjects", FixedPoint(50),
        EliminateMapObjects) :: Nil
    }
  }

  // Encoders picked up implicitly by the `deserialize[...]` calls below.
  implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]()
  implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]()

  test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
    // Array[Int]: expect a single Invoke of toIntArray on the input column.
    val intObjType = ObjectType(classOf[Array[Int]])
    val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val intQuery = intInput.deserialize[Array[Int]].analyze
    val intOptimized = Optimize.execute(intQuery)
    val intExpected = DeserializeToObject(
      Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
      AttributeReference("obj", intObjType, true)(), intInput)
    comparePlans(intOptimized, intExpected)

    // Array[Double]: same shape, with toDoubleArray.
    val doubleObjType = ObjectType(classOf[Array[Double]])
    val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
    val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
    val doubleOptimized = Optimize.execute(doubleQuery)
    val doubleExpected = DeserializeToObject(
      Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
      AttributeReference("obj", doubleObjType, true)(), doubleInput)
    comparePlans(doubleOptimized, doubleExpected)
  }
}
Source File: RewriteSubquerySuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Checks the plan produced for a predicate subquery after a column-pruning batch followed
 * by a "Rewrite Subquery" batch; the expected plan is a LeftSemi join on `a === x`.
 *
 * NOTE(review): the listing is truncated by extraction (relation schemas, part of the
 * second batch, the subquery expression and all closing braces are missing); confirm
 * against the upstream file.
 */
class RewriteSubquerySuite extends PlanTest {

  // Two batches run in order: column pruning to a fixed point, then the rewrite batch.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Column Pruning", FixedPoint(100), ColumnPruning) ::
      Batch("Rewrite Subquery", FixedPoint(1),
        RemoveRedundantProject) :: Nil

  test("Column pruning after rewriting predicate subquery") {
    val relation = LocalRelation(', '
    val relInSubquery = LocalRelation(', ', '

    val query = relation.where(''x)))).select('a)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = relation
      .join('x), LeftSemi, Some('a === 'x))

    comparePlans(optimized, correctAnswer)

Source File: ComputeCurrentTimeSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Tests for the ComputeCurrentTime rule: every `CurrentTimestamp` / `CurrentDate` in a plan
 * is expected to be replaced by a literal, and all occurrences in one plan must get the
 * same value.
 *
 * NOTE(review): closing braces, the child of the first `Project` and the expression
 * returned from each `transformAllExpressions` case were absent from this listing. They
 * have been restored from the parallel second test and from the requirement that the
 * partial function returns an expression; confirm against the upstream Spark file.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  /** Rule executor containing only the rule under test, applied once. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    // Bracket the optimization with wall-clock readings; the literals are compared against
    // these bounds in units of 1/1000 of a millisecond.
    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    // Used purely as a traversal: record each literal and return it unchanged.
    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    // Used purely as a traversal: record each literal and return it unchanged.
    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
}
Source File: ReorderAssociativeOperatorSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the ReorderAssociativeOperator rule, applied once: the expected answers below
 * show constant operands of chained `+` and `*` folded into a single constant, while the
 * expression involving `Rand(0)` is expected to stay as written.
 *
 * NOTE(review): the listing is truncated by extraction (the relation schema, the
 * receivers of the select and join calls and all closing braces are missing); confirm
 * against the upstream file.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // Rule executor containing only the rule under test, applied once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil

  val testRelation = LocalRelation(', ', '

  test("Reorder associative operators") {
    val originalQuery =
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    // Each expected expression keeps the alias text of the original, unfolded expression.
    val correctAnswer =
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)

    comparePlans(optimized, correctAnswer)

  test("nested expression with aggregate operator") {
    val originalQuery ="t1")
        .join("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    // The rule is expected to leave this query unchanged.
    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
Source File: AggregateOptimizeSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}

/**
 * Tests for RemoveRepetitionFromGroupExpressions (the only rule in the batch below), using
 * an analyzer built with case sensitivity and group-by-ordinal both switched off.
 *
 * NOTE(review): the listing is truncated by extraction (the relation schema, the
 * receivers of two select calls, possibly further rules in the batch, and all closing
 * braces are missing); confirm against the upstream file.
 */
class AggregateOptimizeSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      RemoveRepetitionFromGroupExpressions) :: Nil

  val testRelation = LocalRelation(', ', '

  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    // Only the column grouping key is expected to remain.
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    // A single literal key is expected to be kept rather than an empty grouping list.
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)

  test("Remove aliased literals") {
    val query ='a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer ='a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  test("remove repetition in grouping expression") {
    // The last two keys repeat the first two with operands swapped and upper-case names.
    val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
Source File: SparkPlanner.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.internal.SQLConf

/**
 * Planner that assembles the list of planning strategies and offers a helper for building
 * a scan combined with filtering and projection.
 *
 * NOTE(review): `pruneFilterProject` is truncated by extraction in this listing (the
 * definition of `filterCondition`, parts of the `if` condition and of both branches, and
 * the closing braces are missing); confirm against the upstream file.
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val experimentalMethods: ExperimentalMethods)
  extends SparkStrategies {

  // Number of shuffle partitions, read from the SQL configuration.
  def numPartitions: Int = conf.numShufflePartitions

  // Strategy order: experimental strategies, then extraPlanningStrategies, then the fixed
  // list below.
  override def strategies: Seq[Strategy] =
    experimentalMethods.extraStrategies ++
      extraPlanningStrategies ++ (
      DataSourceV2Strategy ::
      FileSourceStrategy ::
      DataSourceStrategy(conf) ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)

  // Builds a physical plan from a project list, filter predicates and a scan builder.
  // When the projection consists of exactly the projected attributes and the filters only
  // reference those attributes, the scan is built over the project list directly;
  // otherwise the scan covers the union of projected and filtered attributes and a
  // ProjectExec is placed on top.
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    // Attributes referenced by the projection and by the filters, respectively.
    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    val filterCondition: Option[Expression] =

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    if (AttributeSet( == projectSet &&
        filterSet.subsetOf(projectSet)) {
      // When it is possible to just use column pruning to get the right projection and
      // when the columns of this projection are enough to evaluate all filter conditions,
      // just do a scan followed by a filter, with no extra project.
      val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]), scan)).getOrElse(scan)
    } else {
      val scan = scanBuilder((projectSet ++ filterSet).toSeq)
      ProjectExec(projectList,, scan)).getOrElse(scan))
Source File: PruneFileSourcePartitions.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Logical rule that, for a filtered relation backed by a `CatalogFileIndex` with a
 * partition schema, prunes the file index using the partition-key filters and rebuilds
 * the relation with updated size statistics.
 *
 * NOTE(review): the listing is truncated by extraction (part of the pattern, the bodies
 * of several `val` definitions, the `else` branch and all closing braces are missing);
 * confirm against the upstream file.
 */
private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    // Only applies when there is at least one filter and a partition schema is defined.
    case op @ PhysicalOperation(projects, filters,
        logicalRelation @
          LogicalRelation(fsRelation @
              catalogFileIndex: CatalogFileIndex,
        if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
      // The attribute name of predicate could be different than the one in schema in case of
      // case insensitive, we should change them to match the one in schema, so we donot need to
      // worry about case sensitivity anymore.
      val normalizedFilters = { e =>
        e transform {
          case a: AttributeReference =>

      val sparkSession = fsRelation.sparkSession
      val partitionColumns =
          partitionSchema, sparkSession.sessionState.analyzer.resolver)
      val partitionSet = AttributeSet(partitionColumns)
      val partitionKeyFilters =

      if (partitionKeyFilters.nonEmpty) {
        // Restrict the file index to the partitions matching the partition-key filters.
        val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
        val prunedFsRelation =
          fsRelation.copy(location = prunedFileIndex)(sparkSession)
        // Change table stats based on the sizeInBytes of pruned files
        val withStats =
          stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes)))))
        val prunedLogicalRelation = logicalRelation.copy(
          relation = prunedFsRelation, catalogTable = withStats)
        // Keep partition-pruning predicates so that they are visible in physical planning
        val filterExpression = filters.reduceLeft(And)
        val filter = Filter(filterExpression, prunedLogicalRelation)
        Project(projects, filter)
      } else {
Example 162
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.util.Utils

  // Returns a copy of this relation with a replaced `output`.
  // NOTE(review): the expression assigned to `output` and the closing brace are missing
  // from this listing; confirm against the upstream file.
  override def newInstance(): LogicalRelation = {
    this.copy(output =

  /**
   * Refreshes the file location when the underlying relation is a `HadoopFsRelation`;
   * for any other relation this is a no-op.
   */
  override def refresh(): Unit = relation match {
    case fs: HadoopFsRelation => fs.location.refresh()
    case _ =>  // Do nothing.
  }

  /**
   * One-line description: the output attributes rendered by `Utils.truncatedString` with a
   * comma separator, followed by the relation itself.
   */
  override def simpleString: String = {
    val renderedOutput = Utils.truncatedString(output, ",")
    s"Relation[$renderedOutput] $relation"
  }

/** Factory methods that derive the output attributes from the relation schema. */
object LogicalRelation {
  /** Builds a relation without a catalog table; `isStreaming` defaults to false. */
  def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation =
    LogicalRelation(relation, relation.schema.toAttributes, None, isStreaming)

  /** Builds a non-streaming relation associated with the given catalog table. */
  def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation =
    LogicalRelation(relation, relation.schema.toAttributes, Some(table), false)
}
Source File: InsertIntoDataSourceCommand.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

/**
 * Command that inserts the result of `query` into an [[InsertableRelation]].
 *
 * NOTE(review): the return value of `run` and the closing braces were absent from this
 * listing; `run` is declared to return `Seq[Row]`, so an empty result has been restored.
 * Confirm against the upstream Spark file.
 *
 * @param logicalRelation target relation; its underlying relation is cast to
 *                        `InsertableRelation`
 * @param query           plan producing the rows to insert
 * @param overwrite       passed through to `InsertableRelation.insert`
 */
case class InsertIntoDataSourceCommand(
    logicalRelation: LogicalRelation,
    query: LogicalPlan,
    overwrite: Boolean)
  extends RunnableCommand {

  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
    val data = Dataset.ofRows(sparkSession, query)
    // Apply the schema of the existing table to the new data.
    val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
    relation.insert(df, overwrite)

    // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
    // data source relation.
    sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)

    // The command produces no result rows.
    Seq.empty[Row]
  }
}

Source File: ddl.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.util.Locale

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.{DDLUtils, RunnableCommand}
import org.apache.spark.sql.types._

/**
 * Runnable command that resolves a data source from `provider`/`options` and registers the
 * resulting plan as a temporary view (global or session-local, depending on `global`).
 *
 * Construction fails with an `AnalysisException` when `tableIdent` carries a database.
 *
 * NOTE(review): closing braces are missing throughout, and the `argString` body is garbled in
 * this listing (the first concatenated line is not valid Scala as shown).
 */
case class CreateTempViewUsing(
    tableIdent: TableIdentifier,
    userSpecifiedSchema: Option[StructType],
    replace: Boolean,
    global: Boolean,
    provider: String,
    options: Map[String, String]) extends RunnableCommand {

  // Constructor-time validation: a temporary view name must not be database-qualified.
  if (tableIdent.database.isDefined) {
    throw new AnalysisException(
      s"Temporary view '$tableIdent' should not have specified a database")

  // NOTE(review): part of the expression on the next line was lost during extraction.
  override def argString: String = {
    s"[tableIdent:$tableIdent " + + " ").getOrElse("") +
      s"replace:$replace " +
      s"provider:$provider " +

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // The Hive provider is rejected; the comparison is case-insensitive via Locale.ROOT lowercasing.
    if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
      throw new AnalysisException("Hive data source can only be used with tables, " +
        "you can't use it with CREATE TEMP VIEW USING")

    val dataSource = DataSource(
      userSpecifiedSchema = userSpecifiedSchema,
      className = provider,
      options = options)

    val catalog = sparkSession.sessionState.catalog
    // Wrap the resolved relation in a LogicalRelation and take the Dataset's logical plan as the view body.
    val viewDefinition = Dataset.ofRows(
      sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan

    // Only the unqualified table name is used when registering the view.
    if (global) {
      catalog.createGlobalTempView(tableIdent.table, viewDefinition, replace)
    } else {
      catalog.createTempView(tableIdent.table, viewDefinition, replace)


/**
 * Runnable command identified by a table name; per the inline comment it refreshes that
 * table's metadata.
 * NOTE(review): the statements of `run` are missing in this listing; only the comment survives.
 */
case class RefreshTable(tableIdent: TableIdentifier)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Refresh the given table's metadata. If this table is cached as an InMemoryRelation,
    // drop the original cached version and make the new version cached lazily.

/**
 * Runnable command parameterised by a resource `path`.
 * NOTE(review): the body of `run` is missing in this listing, so what is done with `path`
 * cannot be stated from here.
 */
case class RefreshResource(path: String)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
Source File: SaveIntoDataSourceCommand.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.CreatableRelationProvider

/**
 * Runnable command that carries a query, a `CreatableRelationProvider`, write options and a
 * `SaveMode`.
 *
 * NOTE(review): the call whose argument list appears inside `run` has lost its receiver in
 * this listing; presumably it is a call on `dataSource` — confirm against the full source.
 */
case class SaveIntoDataSourceCommand(
    query: LogicalPlan,
    dataSource: CreatableRelationProvider,
    options: Map[String, String],
    mode: SaveMode) extends RunnableCommand {

  // Exposes the query as an inner child of this command.
  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
      sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))


Source File: cache.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * Runnable command for caching a table, optionally defined by a plan (CACHE TABLE AS SELECT).
 *
 * @param tableIdent table to cache; must not be database-qualified when `plan` is defined
 * @param plan       optional plan that is first registered as a temp view under the table name
 * @param isLazy     when false, the eager-caching branch is taken
 *
 * NOTE(review): the eager-caching statements and closing braces are missing in this listing.
 */
case class CacheTableCommand(
    tableIdent: TableIdentifier,
    plan: Option[LogicalPlan],
    isLazy: Boolean) extends RunnableCommand {
  // Throws IllegalArgumentException when a plan is supplied together with a database name.
  require(plan.isEmpty || tableIdent.database.isEmpty,
    "Database name is not allowed in CACHE TABLE AS SELECT")

  override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // When a plan is given, register it as a temp view named after the quoted identifier.
    plan.foreach { logicalPlan =>
      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)

    if (!isLazy) {
      // Performs eager caching


// Runnable command that acts on the table named by `tableIdent` when `ifExists` is false or
// the table exists in the session catalog.
// NOTE(review): the rest of the parameter list (including `ifExists`), the guarded statement
// and the closing braces are missing in this listing.
case class UncacheTableCommand(
    tableIdent: TableIdentifier,
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val tableId = tableIdent.quotedString
    if (!ifExists || sparkSession.catalog.tableExists(tableId)) {

  // NOTE(review): presumably a member of ClearCacheCommand, whose declaration is not shown here.
  // Ignores `newArgs` and always returns a fresh ClearCacheCommand.
  override def makeCopy(newArgs: Array[AnyRef]): ClearCacheCommand = ClearCacheCommand()
Source File: InsertIntoDataSourceDirCommand.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources._

/**
 * Runnable command that writes the result of `query` to a directory using the data source
 * named by `provider`.
 *
 * @param storage   storage format; its `locationUri` must be non-empty (asserted in `run`)
 * @param provider  data source name; must be non-empty and resolve to a `FileFormat` class
 * @param query     plan producing the rows to write
 * @param overwrite selects `SaveMode.Overwrite`, otherwise `SaveMode.ErrorIfExists`
 *
 * NOTE(review): the `pathOption` and `options` lines are garbled in this listing, and the
 * closing braces plus the value returned by `run` are missing.
 */
case class InsertIntoDataSourceDirCommand(
    storage: CatalogStorageFormat,
    provider: String,
    query: LogicalPlan,
    overwrite: Boolean) extends RunnableCommand {

  override protected def innerChildren: Seq[LogicalPlan] = query :: Nil

  override def run(sparkSession: SparkSession): Seq[Row] = {
    assert(storage.locationUri.nonEmpty, "Directory path is required")
    assert(provider.nonEmpty, "Data source is required")

    // Create the relation based on the input logical plan: `query`.
    val pathOption ="path" -> CatalogUtils.URIToString(_))

    val dataSource = DataSource(
      className = provider,
      options = ++ pathOption,
      catalogTable = None)

    // Reject providers whose resolved class is not a FileFormat.
    val isFileFormat = classOf[FileFormat].isAssignableFrom(dataSource.providingClass)
    if (!isFileFormat) {
      throw new SparkException(
        "Only Data Sources providing FileFormat are supported: " + dataSource.providingClass)

    val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists
    try {
      // Accessing `toRdd` on the executed plan is what drives the write here.
      sparkSession.sessionState.executePlan(dataSource.planForWriting(saveMode, query)).toRdd
    } catch {
      // Analysis failures are logged with the target location and then rethrown unchanged.
      case ex: AnalysisException =>
        logError(s"Failed to write to directory " + storage.locationUri.toString, ex)
        throw ex

Source File: QueryExecutionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Tests how `QueryExecution.toString` behaves when a planner strategy returns nothing, throws
 * an `AnalysisException`, or throws an `Error`.
 * NOTE(review): closing braces/parentheses and the assertions for the first two scenarios
 * are missing in this listing.
 */
class QueryExecutionSuite extends SharedSQLContext {
  test("toString() exception/error handling") {
    // Scenario 1: a strategy that plans nothing.
    spark.experimental.extraStrategies = Seq(
        new SparkStrategy {
          override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil

    // A `def`, so each use builds a fresh QueryExecution against the strategies installed at that time.
    def qe: QueryExecution = new QueryExecution(spark, OneRowRelation())

    // Nothing!

    // Throw an AnalysisException - this should be captured.
    spark.experimental.extraStrategies = Seq(
      new SparkStrategy {
        override def apply(plan: LogicalPlan): Seq[SparkPlan] =
          throw new AnalysisException("exception")

    // Throw an Error - this should not be captured.
    spark.experimental.extraStrategies = Seq(
      new SparkStrategy {
        override def apply(plan: LogicalPlan): Seq[SparkPlan] =
          throw new Error("error")
    val error = intercept[Error](qe.toString)
Example 169
package org.apache.spark.sql.hive

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.hive.execution._

/**
 * Hive-specific planner strategies, mixed into a `SparkPlanner` via the self-type.
 * NOTE(review): closing braces are missing throughout, parts of the CTAS branch are garbled,
 * and the header of the object that owns the last `apply` is not shown in this listing.
 */
private[hive] trait HiveStrategies {
  // Possibly being too clever with types here... or not clever enough.
  self: SparkPlanner =>

  val sparkSession: SparkSession

  // Plans logical ScriptTransformation nodes into the physical ScriptTransformation operator.
  object Scripts extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.ScriptTransformation(input, script, output, child, ioschema) =>
        val hiveIoSchema = HiveScriptIOSchema(ioschema)
        ScriptTransformation(input, script, output, planLater(child), hiveIoSchema) :: Nil
      case _ => Nil

  // Plans inserts into MetastoreRelation tables and CREATE TABLE AS SELECT with the "hive" provider.
  object DataSinks extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.InsertIntoTable(
          table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
        InsertIntoHiveTable(table, partition, planLater(child), overwrite, ifNotExists) :: Nil

      // `provider.get` assumes the provider is defined for a CreateTable that carries a query.
      case CreateTable(tableDesc, mode, Some(query)) if tableDesc.provider.get == "hive" =>
        // NOTE(review): the condition and both branch bodies below are partly lost in this listing.
        val newTableDesc = if ( {
          // add default serde
            serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
        } else {

        // Currently we will never hit this branch, as SQL string API can only use `Ignore` or
        // `ErrorIfExists` mode, and `DataFrameWriter.saveAsTable` doesn't support hive serde
        // tables yet.
        if (mode == SaveMode.Append || mode == SaveMode.Overwrite) {
          throw new AnalysisException(
            "CTAS for hive serde tables does not support append or overwrite semantics.")

        // Fall back to the session's current database when the identifier has none.
        val dbName = tableDesc.identifier.database.getOrElse(sparkSession.catalog.currentDatabase)
        val cmd = CreateHiveTableAsSelectCommand(
          newTableDesc.copy(identifier = tableDesc.identifier.copy(database = Some(dbName))),
          mode == SaveMode.Ignore)
        ExecutedCommandExec(cmd) :: Nil

      case _ => Nil

    // NOTE(review): the enclosing object header for this strategy is missing in this listing.
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
        // Filter out all predicates that only deal with partition keys, these are given to the
        // hive table scan operator to be used for partition pruning.
        val partitionKeyIds = AttributeSet(relation.partitionKeys)
        val (pruningPredicates, otherPredicates) = predicates.partition { predicate =>
          !predicate.references.isEmpty &&

          HiveTableScanExec(_, relation, pruningPredicates)(sparkSession)) :: Nil
      case _ =>
Example 170
package org.apache.spark.sql.hive.execution

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.hive.MetastoreRelation

/**
 * Runnable command that creates a Hive table from `tableDesc` and fills it with the result of
 * `query`.
 *
 * @param tableDesc      catalog description of the table to create
 * @param query          plan producing the rows for the new table
 * @param ignoreIfExists when true, an already-existing table makes the command a no-op instead
 *                       of raising an `AnalysisException`
 *
 * NOTE(review): many lines (format defaults, the insert call, closing braces) are missing or
 * garbled in this listing.
 */
case class CreateHiveTableAsSelectCommand(
    tableDesc: CatalogTable,
    query: LogicalPlan,
    ignoreIfExists: Boolean)
  extends RunnableCommand {

  private val tableIdentifier = tableDesc.identifier

  override def innerChildren: Seq[LogicalPlan] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Lazy: the table is only created in the catalog when this value is first forced.
    lazy val metastoreRelation: MetastoreRelation = {
      import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
      import org.apache.hadoop.mapred.TextInputFormat

      // NOTE(review): the right-hand sides of these storage-format defaults are lost in this listing.
      val withFormat =
          inputFormat =
          outputFormat =
              .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)),
          serde =[LazySimpleSerDe].getName)),
          compressed =

      val withSchema = if (withFormat.schema.isEmpty) {
        // Hive doesn't support specifying the column list for target table in CTAS
        // However we don't think SparkSQL should follow that.
        tableDesc.copy(schema = query.output.toStructType)
      } else {

      sparkSession.sessionState.catalog.createTable(withSchema, ignoreIfExists = false)

      // Get the Metastore Relation
      sparkSession.sessionState.catalog.lookupRelation(tableIdentifier) match {
        case r: MetastoreRelation => r
    // TODO ideally, we should get the output data ready first and then
    // add the relation into catalog, just in case of failure occurs while data
    // processing.
    if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) {
      if (ignoreIfExists) {
        // table already exists, will do nothing, to keep consistent with Hive
      } else {
        throw new AnalysisException(s"$tableIdentifier already exists.")
    } else {
      try {
          metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd
      } catch {
        // On any non-fatal failure the table is dropped again before rethrowing.
        case NonFatal(e) =>
          // drop the created table.
          sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true,
            purge = false)
          throw e


  /**
   * Argument summary for plan descriptions: the table's database followed by its table name.
   * The interpolation of `tableDesc.database` previously carried a stray `}` that was printed
   * literally after the database value; it has been removed.
   * NOTE(review): the final concatenated segment and the closing brace are missing in this
   * listing and are left exactly as found.
   */
  override def argString: String = {
    s"[Database:${tableDesc.database}, " +
    s"TableName: ${tableDesc.identifier.table}, " +
Source File: PruneFileSourcePartitionsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions, TableFileCatalog}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

/**
 * Test suite that runs the `PruneFileSourcePartitions` rule once over a projected, filtered
 * `LogicalRelation` backed by a partitioned parquet table.
 * NOTE(review): the SQL call wrapping the CREATE TABLE text, the `filterNot` predicate, the
 * final assertion and the closing braces are missing in this listing.
 */
class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {

  // Single-batch executor applying only the rule under test, once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil

  test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
    withTable("test") {
      withTempDir { dir =>
            |CREATE EXTERNAL TABLE test(i int)
            |PARTITIONED BY (p int)
            |STORED AS parquet
            |LOCATION '${dir.getAbsolutePath}'""".stripMargin)

        val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
        val tableFileCatalog = new TableFileCatalog(spark, tableMeta, 0)

        // Data schema is derived from the table schema by filtering out some fields.
        val dataSchema = StructType(tableMeta.schema.filterNot { f =>
        val relation = HadoopFsRelation(
          location = tableFileCatalog,
          partitionSchema = tableMeta.partitionSchema,
          dataSchema = dataSchema,
          bucketSpec = None,
          fileFormat = new ParquetFileFormat(),
          options = Map.empty)(sparkSession = spark)

        val logicalRelation = LogicalRelation(relation, catalogTable = Some(tableMeta))
        // Project both columns over a filter on the partition column `p`.
        val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze

        val optimized = Optimize.execute(query)
Source File: SQLBuilderTest.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import scala.util.control.NonFatal

import org.apache.spark.sql.{DataFrame, Dataset, QueryTest}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.hive.test.TestHiveSingleton

/**
 * Base class for tests that compare generated SQL text against an expected string, for
 * expressions, logical plans and DataFrames.
 * NOTE(review): the failure-reporting calls, the bodies of the multi-line messages and the
 * closing braces are missing in this listing.
 */
abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
  // Asserts that the SQL rendered for expression `e` equals `expectedSQL`.
  protected def checkSQL(e: Expression, expectedSQL: String): Unit = {
    val actualSQL = e.sql
    try {
      assert(actualSQL === expectedSQL)
    } catch {
      case cause: Throwable =>
          s"""Wrong SQL generated for the following expression:

  // Converts `plan` to SQL with SQLBuilder, asserts it equals `expectedSQL`, then checks that
  // running the generated SQL gives the same answer as executing the plan directly.
  protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = {
    val generatedSQL = try new SQLBuilder(plan).toSQL catch { case NonFatal(e) =>
        s"""Cannot convert the following logical query plan to SQL:

    try {
      assert(generatedSQL === expectedSQL)
    } catch {
      case cause: Throwable =>
          s"""Wrong SQL generated for the following logical query plan:

    checkAnswer(spark.sql(generatedSQL), Dataset.ofRows(spark, plan))

  // DataFrame overload: delegates to the plan overload using the DataFrame's analyzed plan.
  protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
    checkSQL(df.queryExecution.analyzed, expectedSQL)
Source File: SubstituteUnresolvedOrdinals.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType

/**
 * Rule that rewrites integer literals used as ORDER BY / GROUP BY items into
 * `UnresolvedOrdinal` nodes, gated by `conf.orderByOrdinal` and `conf.groupByOrdinal`.
 * NOTE(review): the collection calls that feed the `case` blocks, one case body and the
 * closing braces are missing in this listing.
 */
class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  // True only for a Literal whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Sort: only touched when the feature is enabled and at least one sort key is an int literal.
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          // withOrigin carries the original node's origin over to the replacement node.
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    // Aggregate: same gating, applied to the grouping expressions.
    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = {
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
Source File: ResolveTableValuedFunctions.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
          "numPartitions" -> IntegerType) {
        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
          Range(start, end, step, Some(numPartitions))

  /**
   * Resolves `UnresolvedTableValuedFunction` nodes whose arguments are all resolved by looking
   * the function name up in `builtinFunctions` and trying each registered alternative.
   * Analysis fails when the name is unknown or, per the error text, when no alternative can be
   * applied to the given argument types.
   * NOTE(review): the bodies of the inner `case` arms, the start of the `argTypes` expression,
   * part of the error message and the closing braces are missing in this listing.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
      builtinFunctions.get(u.functionName) match {
        case Some(tvf) =>
          // Try every (argument list, resolver) alternative; implicitCast yields Some on a match.
          val resolved = tvf.flatMap { case (argList, resolver) =>
            argList.implicitCast(u.functionArgs) match {
              case Some(casted) =>
              case _ =>
          // First successful alternative wins; otherwise build the error below.
          resolved.headOption.getOrElse {
            val argTypes =", ")
              s"""error: table-valued function ${u.functionName} with alternatives:
                |${ => s" ($x)").mkString("\n")}
                |cannot be applied to: (${argTypes})""".stripMargin)
        case _ =>
          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
Source File: ResolveInlineTables.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}

  /**
   * Converts an `UnresolvedInlineTable` into a `LocalRelation`: a common type is found per
   * column, then each cell expression is evaluated (cast to the column type when its own type
   * differs) to build the rows.
   * Analysis fails when a column has incompatible types or an expression cannot be evaluated.
   * NOTE(review): the collection calls that feed the closures, the same-type branch body and
   * the closing braces are missing in this listing.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // For each column, traverse all the values and find a common data type and nullability.
    val fields = { case (column, name) =>
      val inputTypes =
      val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      // A column is nullable as soon as any of its expressions is nullable.
      StructField(name, tpe, nullable = column.exists(_.nullable))
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = { row =>
      InternalRow.fromSeq( { case (e, ci) =>
        // `ci` is the column index used to look up the target type for this cell.
        val targetType = fields(ci).dataType
        try {
          if (e.dataType.sameType(targetType)) {
          } else {
            Cast(e, targetType).eval()
        } catch {
          // Non-fatal evaluation errors are reported as analysis failures naming the expression.
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")

    LocalRelation(attributes, newRows)
Source File: AnalysisException.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * Exception carrying a message plus optional line, start position, logical plan and cause.
 * The constructor is `protected[sql]`; the cause (or null when absent) is passed to `Exception`.
 * NOTE(review): the receivers of the `.map`-style expressions, the tail of `withPosition` and
 * `getSimpleMessage`, and the closing braces are missing in this listing.
 */
class AnalysisException protected[sql] (
    val message: String,
    val line: Option[Int] = None,
    val startPosition: Option[Int] = None,
    val plan: Option[LogicalPlan] = None,
    val cause: Option[Throwable] = None)
  extends Exception(message, cause.orNull) with Serializable {

  // Builds a new exception with the same message and the given line/start position.
  def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
    val newException = new AnalysisException(message, line, startPosition)

  // Full message: the simple message followed by the plan annotation (empty when absent).
  override def getMessage: String = {
    val planAnnotation = => s";\n$p").getOrElse("")
    getSimpleMessage + planAnnotation

  // Outputs an exception without the logical plan.
  // For testing only
  def getSimpleMessage: String = {
    val lineAnnotation = => s" line $l").getOrElse("")
    val positionAnnotation = => s" pos $p").getOrElse("")
Source File: SameResultSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.util._

/**
 * Tests `sameResult` on analyzed logical plans for relations, projections, filters, sorts and
 * unions.
 * NOTE(review): the relation definitions and the projection arguments are garbled, and the
 * closing braces are missing in this listing.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation(', ', '
  val testRelation2 = LocalRelation(', ', '

  // Analyzes both plans and fails, with a side-by-side rendering, when `sameResult` differs
  // from the expected `result`.
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val aAnalyzed = a.analyze
    val bAnalyzed = b.analyze

    if (aAnalyzed.sameResult(bAnalyzed) != result) {
      val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")

  test("relations") {
    assertSameResult(testRelation, testRelation2)

  test("projections") {
    assertSameResult('a, 'b),'a, 'b))
    assertSameResult('b, 'a),'b, 'a))

    // The next two cases are expected NOT to be the same result.
    assertSameResult(testRelation,'a), result = false)
    assertSameResult('b, 'a),'a, 'b), result = false)

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))

  test("union") {
    // Child order is swapped between the two unions.
    assertSameResult(Union(Seq(testRelation, testRelation2)),
      Union(Seq(testRelation2, testRelation)))
Source File: RuleExecutorSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

/**
 * Tests `RuleExecutor` batch strategies (`Once`, `FixedPoint`) using a rule that decrements
 * positive integer literals by one per application.
 * NOTE(review): the body of the `intercept` block and the closing braces are missing in this
 * listing.
 */
class RuleExecutorSuite extends SparkFunSuite {
  // Replaces each positive IntegerLiteral `i` with Literal(i - 1).
  object DecrementLiterals extends Rule[Expression] {
    def apply(e: Expression): Expression = e transform {
      case IntegerLiteral(i) if i > 0 => Literal(i - 1)

  test("only once") {
    object ApplyOnce extends RuleExecutor[Expression] {
      val batches = Batch("once", Once, DecrementLiterals) :: Nil

    // One application: 10 becomes 9.
    assert(ApplyOnce.execute(Literal(10)) === Literal(9))

  test("to fixed point") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil

    // With a limit of 100 iterations the literal reaches 0.
    assert(ToFixedPoint.execute(Literal(10)) === Literal(0))

  test("to maxIterations") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil

    // With a limit of 10 iterations the executor is expected to raise a TreeNodeException.
    val message = intercept[TreeNodeException[LogicalPlan]] {
    assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
Source File: ConvertToLocalRelationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests that the `ConvertToLocalRelation` rule folds a projection over a `LocalRelation` into
 * a single `LocalRelation`.
 * NOTE(review): the relation attribute lists and the start of the projection expression are
 * garbled, and the closing braces are missing in this listing.
 */
class ConvertToLocalRelationSuite extends PlanTest {

  // Applies only ConvertToLocalRelation, iterating up to 100 times.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    // Expected rows: the second value of each input row incremented by one.
    val correctAnswer = LocalRelation(
      LocalRelation(', ',
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal =
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)

Source File: CollapseWindowSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests the `CollapseWindow` rule: adjacent windows with identical partition and order specs
 * are expected to collapse, while windows differing in either spec are expected to stay
 * untouched.
 * NOTE(review): part of the expected plan in the first test and the closing braces are
 * missing in this listing.
 */
class CollapseWindowSuite extends PlanTest {
  // Applies only CollapseWindow, iterating up to 10 times.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  // Two partition specs and two order specs, used to build matching and non-matching windows.
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.window(Seq(
        min(a).as('min_a)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Same partition spec, different order spec.
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    // Different partition spec, same order spec.
    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
Source File: RewriteDistinctAggregatesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests the `RewriteDistinctAggregates` rule against aggregations with single and multiple
 * distinct groups.
 * NOTE(review): the relation definition is garbled, most test bodies are cut short, and the
 * closing braces are missing in this listing.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, '

  // Passes only when the plan has the shape Aggregate(Aggregate(Expand)); fails otherwise.
  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    // The rule's output is compared against its own input here.
    comparePlans(input, rewrite)

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Example 182
Source File: OptimizerExtendableSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

  /**
   * Optimizer subclass that appends two extra batches (one `Once`, one `FixedPoint(100)`) to
   * the batches inherited from `SimpleTestOptimizer`.
   * NOTE(review): closing braces are missing in this listing.
   */
  class ExtendedOptimizer extends SimpleTestOptimizer {

    // rules set to DummyRule, would not be executed anyways
    val myBatches: Seq[Batch] = {
      Batch("once", Once,
        DummyRule) ::
      Batch("fixedPoint", FixedPoint(100),
        DummyRule) :: Nil

    // Inherited batches come first, followed by the extra ones.
    override def batches: Seq[Batch] = super.batches ++ myBatches

  test("Extending batches possible") {
    // test simply instantiates the new extended optimizer
    val extendedOptimizer = new ExtendedOptimizer()
Source File: CollapseRepartitionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan tests for the `CollapseRepartition` optimizer rule.
 *
 * Every test analyzes a query, runs it through `Optimize`, and compares the
 * result with an expected analyzed plan via `comparePlans`.
 *
 * NOTE(review): this extract is truncated. The closing braces are absent, the
 * column list of `testRelation` is incomplete, and each `val query = testRelation`
 * is missing the repartition/distribute calls the test name refers to.
 */
class CollapseRepartitionSuite extends PlanTest {
  // Rule executor that applies only CollapseRepartition, with a FixedPoint(10) strategy.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseRepartition", FixedPoint(10),
        CollapseRepartition) :: Nil

  // NOTE(review): the attribute definitions on this line appear to have been lost.
  val testRelation = LocalRelation(', '

  // Expected plan: a single repartition(20) over testRelation.
  test("collapse two adjacent repartitions into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.repartition(20).analyze

    comparePlans(optimized, correctAnswer)

  // Expected plan: a single distribute('a)(20) over testRelation.
  test("collapse repartition and repartitionBy into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)

  // Expected plan: a single distribute('a)(10) over testRelation.
  test("collapse repartitionBy and repartition into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(10).analyze

    comparePlans(optimized, correctAnswer)

  // Expected plan: a single distribute('a)(20) over testRelation.
  test("collapse two adjacent repartitionBys into one") {
    val query = testRelation

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
Source File: ComputeCurrentTimeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Plan tests for the `ComputeCurrentTime` rule.
 *
 * Both tests project the same current-time expression twice, optimize the plan,
 * collect the `Literal` values found in it, and assert that there are exactly two,
 * that both lie between samples taken before and after optimization, and that
 * they are equal to each other.
 *
 * NOTE(review): this extract is truncated. Closing braces are absent, the child
 * plan of the first `Project` is missing, and the partial functions passed to
 * `transformAllExpressions` lack their result expression.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  // Rule executor that applies ComputeCurrentTime a single time (Once).
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),

    // Bounds are wall-clock milliseconds scaled by 1000; the upper bound adds one
    // millisecond before scaling.
    // NOTE(review): presumably the literal is in microseconds; confirm.
    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    // Collect every literal in the optimized plan as a Long.
    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    // Both aliases must have received the same value.
    assert(lits(0) == lits(1))

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    // Bounds are day numbers derived from the wall clock before and after optimization.
    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    // Collect every literal in the optimized plan as an Int.
    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    // Both aliases must have received the same value.
    assert(lits(0) == lits(1))
Source File: ReorderAssociativeOperatorSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan tests for the `ReorderAssociativeOperator` rule.
 *
 * NOTE(review): this extract is truncated. Closing braces are absent, the column
 * list of `testRelation` is incomplete, and the expressions that build
 * `originalQuery` / `correctAnswer` have lost their leading part (the relation
 * and the select call), so only the projected expressions are visible.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // Rule executor that applies ReorderAssociativeOperator a single time (Once).
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil

  // NOTE(review): the attribute definitions on this line appear to have been lost.
  val testRelation = LocalRelation(', ', '

  // Each input expression below pairs positionally with an expected expression in
  // `correctAnswer`; the expected ones carry the constants combined into one literal
  // (e.g. 3 + 1 + 2 + 4 becomes 10) and are aliased to the original expression text.
  test("Reorder associative operators") {
    val originalQuery =
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer =
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)

    comparePlans(optimized, correctAnswer)

  // The expected plan is the analyzed input itself, i.e. the rule must leave this
  // join + aggregate query unchanged.
  test("nested expression with aggregate operator") {
    val originalQuery ="t1")
        .join("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
Source File: AggregateOptimizeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan tests for simplifying the grouping expressions of an aggregate.
 *
 * Queries are analyzed with a locally built `Analyzer` (case-insensitive, group-by
 * ordinal disabled) and then run through `Optimize`.
 *
 * NOTE(review): this extract is truncated. Closing braces are absent, the column
 * lists of the `LocalRelation`s are incomplete, and the "Remove aliased literals"
 * test has lost the beginning of its `query` / `correctAnswer` expressions.
 */
class AggregateOptimizeSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Only RemoveRepetitionFromGroupExpressions is listed in this batch as shown.
  // NOTE(review): the literal-removal tests below suggest another rule may have
  // been dropped from this extract; confirm against the original file.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      RemoveRepetitionFromGroupExpressions) :: Nil

  // NOTE(review): the attribute definitions on this line appear to have been lost.
  val testRelation = LocalRelation(', ', '

  // Expected: only 'a remains as a grouping expression.
  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  // Expected: a single Literal(0) grouping expression is kept rather than none at all.
  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)

  // Expected: the grouping column 'y, an alias of Literal(1), is dropped from the group-by.
  test("Remove aliased literals") {
    val query ='a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer ='a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)

  // The duplicates differ in operand order and letter case ('A / 'B); the analyzer
  // above is configured with caseSensitiveAnalysis = false.
  test("remove repetition in grouping expression") {
    val input = LocalRelation(', ', '
    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
Source File: SparkPlanner.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.internal.SQLConf

/**
 * Planner that combines caller-supplied strategies with a fixed list of built-in ones.
 *
 * @param sparkContext    the Spark context, exposed as a public `val`
 * @param conf            SQL configuration; `numPartitions` reads `numShufflePartitions` from it
 * @param extraStrategies strategies placed ahead of the built-in list in `strategies`
 *
 * NOTE(review): this extract is truncated. Closing braces are absent, `prunePlans`
 * has no body, and several expressions in `pruneFilterProject` (the value of
 * `filterCondition`, the argument of `AttributeSet(`, and the filter wrapping
 * around `scan`) have lost tokens.
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  // Number of shuffle partitions, taken from the SQL configuration.
  def numPartitions: Int = conf.numShufflePartitions

  // Extra strategies are listed first, followed by the built-in ones in this order.
  def strategies: Seq[Strategy] =
      extraStrategies ++ (
      FileSourceStrategy ::
      DataSourceStrategy ::
      DDLStrategy ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)

  // Collects every PlanLater node in the plan, paired with the logical plan it wraps.
  override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
    plan.collect {
      case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan

  override protected def prunePlans(plans: Iterator[SparkPlan]): Iterator[SparkPlan] = {
    // TODO: We will need to prune bad plans when we improve plan space exploration
    //       to prevent combinatorial explosion.

  /**
   * Builds a physical scan for the given projection and filter predicates.
   *
   * @param projectList            expressions the result must produce
   * @param filterPredicates       predicates to apply to the scan
   * @param prunePushedDownFilters function over the predicate list (not called in the visible lines)
   * @param scanBuilder            creates the scan node for a requested list of attributes
   */
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    // Attributes referenced by the projection and by the filters, respectively.
    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    val filterCondition: Option[Expression] =

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    if (AttributeSet( == projectSet &&
        filterSet.subsetOf(projectSet)) {
      // When it is possible to just use column pruning to get the right projection and
      // when the columns of this projection are enough to evaluate all filter conditions,
      // just do a scan followed by a filter, with no extra project.
      val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]), scan)).getOrElse(scan)
    } else {
      // Otherwise scan the union of projected and filtered attributes and project on top.
      val scan = scanBuilder((projectSet ++ filterSet).toSeq)
      ProjectExec(projectList,, scan)).getOrElse(scan))
Source File: PruneFileSourcePartitions.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Logical rule that, for a filtered file-source relation backed by a
 * `TableFileCatalog`, replaces the relation's location with a catalog pruned by
 * the partition-key filters.
 *
 * NOTE(review): this extract is truncated. Part of the `LogicalRelation(...)`
 * pattern, the right-hand sides of `normalizedFilters`, `partitionColumns` and
 * `partitionKeyFilters`, the `else` branch and all closing braces are missing.
 */
private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    // Applies only when there is at least one filter and the relation has a partition schema.
    case op @ PhysicalOperation(projects, filters,
        logicalRelation @
          LogicalRelation(fsRelation @
              tableFileCatalog: TableFileCatalog,
        if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
      // The attribute name of predicate could be different than the one in schema in case of
      // case insensitive, we should change them to match the one in schema, so we donot need to
      // worry about case sensitivity anymore.
      val normalizedFilters = { e =>
        e transform {
          case a: AttributeReference =>

      val sparkSession = fsRelation.sparkSession
      val partitionColumns =
          partitionSchema, sparkSession.sessionState.analyzer.resolver)
      val partitionSet = AttributeSet(partitionColumns)
      val partitionKeyFilters =

      if (partitionKeyFilters.nonEmpty) {
        // Rebuild the relation on top of a file catalog restricted by the partition filters,
        // keeping the original output attributes on the new logical relation.
        val prunedFileCatalog = tableFileCatalog.filterPartitions(partitionKeyFilters.toSeq)
        val prunedFsRelation =
          fsRelation.copy(location = prunedFileCatalog)(sparkSession)
        val prunedLogicalRelation = logicalRelation.copy(
          relation = prunedFsRelation,
          expectedOutputAttributes = Some(logicalRelation.output))

        // Keep partition-pruning predicates so that they are visible in physical planning
        val filterExpression = filters.reduceLeft(And)
        val filter = Filter(filterExpression, prunedLogicalRelation)
        Project(projects, filter)
      } else {
Source File: LogicalRelation.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.util.Utils

  // NOTE(review): the enclosing class header and the body of `newInstance` are
  // absent from this extract; closing braces are missing as well.
  override def newInstance(): this.type = {

  // Refreshes the file location when the relation is a HadoopFsRelation; any other
  // relation type is left untouched.
  override def refresh(): Unit = relation match {
    case fs: HadoopFsRelation => fs.location.refresh()
    case _ =>  // Do nothing.

  // One-line description: the output attributes (rendered via Utils.truncatedString)
  // followed by the relation.
  override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
Source File: InsertIntoDataSourceCommand.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

/**
 * Command that inserts the result of `query` into the relation wrapped by
 * `logicalRelation`.
 *
 * @param logicalRelation target; its `relation` is cast to `InsertableRelation` in `run`
 * @param query           logical plan producing the rows to insert
 * @param overwrite       passed through unchanged to `InsertableRelation.insert`
 *
 * NOTE(review): this extract is truncated. The statement after
 * "Invalidate the cache.", the return value of `run` and the closing braces are
 * missing.
 */
case class InsertIntoDataSourceCommand(
    logicalRelation: LogicalRelation,
    query: LogicalPlan,
    overwrite: Boolean)
  extends RunnableCommand {

  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Unchecked cast: fails at runtime if the relation is not an InsertableRelation.
    val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
    val data = Dataset.ofRows(sparkSession, query)
    // Apply the schema of the existing table to the new data.
    val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
    relation.insert(df, overwrite)

    // Invalidate the cache.

Source File: InsertIntoHadoopFsRelationCommand.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources


import org.apache.hadoop.fs.Path

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand

)) {
          throw new IOException(s"Unable to clear output " +
            s"directory $qualifiedOutputPath prior to writing to it")
      case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
      case (SaveMode.Ignore, exists) =>
      case (s, exists) =>
        throw new IllegalStateException(s"unsupported save mode $s ($exists)")
    // If we are appending data to an existing dir.
    val isAppend = pathExists && (mode == SaveMode.Append)

    if (doInsertion) {
    } else {
      logInfo("Skipping insertion into a relation that already exists.")

Source File: ddl.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types._

/**
 * Logical command describing a table to create, optionally populated from a query.
 *
 * @param tableDesc catalog description of the table; it must define a provider
 * @param mode      save mode for the creation
 * @param query     optional plan whose result populates the table
 *
 * NOTE(review): this extract is truncated. The call that wraps the condition and
 * message inside `if (query.isEmpty)` and the closing braces are missing.
 */
case class CreateTable(
    tableDesc: CatalogTable,
    mode: SaveMode,
    query: Option[LogicalPlan]) extends Command {
  assert(tableDesc.provider.isDefined, "The table to be created must have a provider.")

  // Without a query, only ErrorIfExists and Ignore are accepted as save modes.
  if (query.isEmpty) {
      mode == SaveMode.ErrorIfExists || mode == SaveMode.Ignore,
      "create table without data insertion can only use ErrorIfExists or Ignore as SaveMode.")

  // The optional query is exposed as an inner child (empty when there is none).
  override def innerChildren: Seq[QueryPlan[_]] = query.toSeq

/**
 * A `Map[String, String]` whose keys are matched case-insensitively.
 *
 * All keys of the wrapped map are lower-cased once at construction time and every
 * lookup lower-cases the requested key before consulting that copy. If the input
 * contains keys that differ only by case, they collapse into a single entry.
 *
 * Note that `+` and `-` return a plain `Map` built from the lower-cased entries,
 * not another `CaseInsensitiveMap`: the keys in the result are lower-case, but
 * lookups on that result are case-sensitive again.
 *
 * @param map the original key/value pairs, with keys in any case
 */
class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
  with Serializable {

  // Lower-cased copy of the input; every operation below works on this map.
  val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))

  /** Looks up `k` ignoring case. */
  override def get(k: String): Option[String] = baseMap.get(k.toLowerCase)

  /** Returns a plain map containing the lower-cased entries plus `kv` (key lower-cased). */
  override def + [B1 >: String](kv: (String, B1)): Map[String, B1] =
    baseMap + kv.copy(_1 = kv._1.toLowerCase)

  /** Iterates over the entries with their lower-cased keys. */
  override def iterator: Iterator[(String, String)] = baseMap.iterator

  /** Returns a plain map of the lower-cased entries without `key` (matched ignoring case). */
  override def -(key: String): Map[String, String] = baseMap - key.toLowerCase
}
Source File: cache.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * Command behind CACHE TABLE, optionally with an AS SELECT plan.
 *
 * @param tableIdent table to cache; it must not carry a database name when `plan` is given
 * @param plan       optional plan that is first registered as a temp view under `tableIdent`
 * @param isLazy     when false, the command performs eager caching
 *
 * NOTE(review): this extract is truncated. The body of `innerChildren`, the
 * caching statements, the return value of `run` and all closing braces are missing.
 */
case class CacheTableCommand(
    tableIdent: TableIdentifier,
    plan: Option[LogicalPlan],
    isLazy: Boolean) extends RunnableCommand {
  require(plan.isEmpty || tableIdent.database.isEmpty,
    "Database name is not allowed in CACHE TABLE AS SELECT")

  override protected def innerChildren: Seq[QueryPlan[_]] = {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // For CACHE TABLE ... AS SELECT: register the plan as a temp view first.
    plan.foreach { logicalPlan =>
      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)

    if (!isLazy) {
      // Performs eager caching


// Runnable command identified by the table it targets.
// NOTE(review): the body of `run` and the closing braces are absent from this extract.
case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {

// Parameterless runnable command.
// NOTE(review): the body of `run` and the closing braces are absent from this extract.
case object ClearCacheCommand extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
Source File: commands.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.streaming.IncrementalExecution
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._

/**
 * Command that renders the plan of `logicalPlan` as rows of text.
 *
 * @param logicalPlan plan to explain
 * @param output      result schema; defaults to a single nullable string column named "plan"
 * @param extended    selects the extended rendering when `codegen` is false
 * @param codegen     selects the codegen rendering; checked before `extended`
 *
 * NOTE(review): this extract is truncated. The non-streaming branch of
 * `queryExecution`, the three `outputString` branches, the conversion to rows and
 * the closing braces are missing.
 */
case class ExplainCommand(
    logicalPlan: LogicalPlan,
    override val output: Seq[Attribute] =
      Seq(AttributeReference("plan", StringType, nullable = true)()),
    extended: Boolean = false,
    codegen: Boolean = false)
  extends RunnableCommand {

  // Run through the optimizer to generate the physical plan.
  override def run(sparkSession: SparkSession): Seq[Row] = try {
    val queryExecution =
      if (logicalPlan.isStreaming) {
        // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
        // output mode does not matter since there is no `Sink`.
        new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0)
      } else {
    val outputString =
      if (codegen) {
      } else if (extended) {
      } else {
  // Planning failures are reported as result rows (one per line of the message)
  // instead of being rethrown.
  } catch { case cause: TreeNodeException[_] =>
    ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
Source File: SparkPlannerSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Checks that the query planner only follows the first candidate plan returned
 * by a strategy.
 *
 * NOTE(review): this extract is truncated. Closing braces are absent and the
 * argument of `UnionExec(` has been lost.
 */
class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Marker leaf: the strategy fails the test if it is ever asked to plan this node.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil

    // Counts how many logical nodes the strategy has planned.
    var planned = 0
    // Every handled case returns the real plan first and a NeverPlanned placeholder second.
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec( :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // Exactly four nodes are expected to have been planned by TestStrategy.
      assert(planned === 4)
    } finally {
      // Always restore the session's strategy list so other tests are unaffected.
      spark.experimental.extraStrategies = Nil
Source File: ExtraStrategiesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Leaf physical operator used by the tests in this file.
 *
 * `doExecute` builds a single row holding the value of `Literal("so fast")` and
 * converts it with an `UnsafeProjection` created from the operator's schema.
 *
 * NOTE(review): the expression that turns `unsafeRow` into the returned RDD and
 * the closing braces are absent from this extract.
 */
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    // The projected row is copied before use.
    val unsafeRow = unsafeProj(row).copy()

  // The operator produces all of its output attributes itself and has no children.
  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil

// Strategy that plans a single-expression Project as a FastOperator when the guard
// holds, and declines (returns Nil) for every other plan.
// NOTE(review): the left-hand side of the guard (`if == "a"`) and the closing
// braces are absent from this extract.
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil

/**
 * Verifies that a strategy installed through `spark.experimental.extraStrategies`
 * is used when planning queries.
 *
 * NOTE(review): this extract is truncated. The `checkAnswer(` calls have lost
 * their first argument (the DataFrame selection) and the closing braces are absent.
 */
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
        // Expected row when the custom strategy is used: the FastOperator value.
        Row("so fast"))

      // Expected row for the two-column case: the original data.
      checkAnswer("a", "b"),
        Row("so slow", 1))
    } finally {
      // Always restore the session's strategy list so other tests are unaffected.
      spark.experimental.extraStrategies = Nil
Source File: MatfastSessionState.scala    From MatRel   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.matfast

import org.apache.spark.sql.matfast.execution.MatfastPlanner
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.{execution => sparkexecution, _}
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.matfast.plans.{MatfastOptimizer}

import scala.collection.immutable

/**
 * Session state for a [[MatfastSession]].
 *
 * On top of the inherited Spark SQL state it holds a separate `MatfastConf`, a
 * `MatfastOptimizer` and a `MatfastPlanner`. Configuration keys that start with
 * `"matfast."` are routed to `matfastConf`; every other key goes to the inherited
 * `conf`.
 *
 * @param matfastSession the session this state belongs to
 */
private[matfast] class MatfastSessionState (matfastSession: MatfastSession)
  extends SessionState(matfastSession) {
  self =>

  protected[matfast] lazy val matfastConf = new MatfastConf

  /** The optimizer inherited from the parent session state. */
  protected[matfast] def getSQLOptimizer = optimizer

  protected[matfast] lazy val matfastOptimizer: MatfastOptimizer = new MatfastOptimizer

  /** Planner built from this session, the inherited conf and the user's extra strategies. */
  protected[matfast] val matfastPlanner: sparkexecution.SparkPlanner = {
    new MatfastPlanner(matfastSession, conf, experimentalMethods.extraStrategies)
  }

  /** Wraps `plan` in the MatFast-specific query execution. */
  override def executePlan(plan: LogicalPlan) =
    new execution.QueryExecution(matfastSession, plan)

  /** Stores `value` under `key` in the configuration selected by the key's prefix. */
  def setConf(key: String, value: String): Unit = {
    if (key.startsWith("matfast.")) matfastConf.setConfString(key, value)
    else conf.setConfString(key, value)
  }

  /** Reads `key` from the configuration selected by the key's prefix. */
  def getConf(key: String): String = {
    if (key.startsWith("matfast.")) matfastConf.getConfString(key)
    else conf.getConfString(key)
  }

  /**
   * Reads `key` from the configuration selected by the key's prefix, falling back
   * to `defaultValue`.
   */
  def getConf(key: String, defaultValue: String): String = {
    // Previously both branches read from `conf`, so values stored through `setConf`
    // under a "matfast." key were never found here and the default was returned.
    // NOTE(review): assumes MatfastConf offers the same two-argument
    // `getConfString(key, defaultValue)` as the inherited conf; confirm.
    if (key.startsWith("matfast.")) matfastConf.getConfString(key, defaultValue)
    else conf.getConfString(key, defaultValue)
  }

  /**
   * All settings from both configurations. The MatFast entries are appended after
   * the inherited ones.
   */
  def getAllConfs: immutable.Map[String, String] = {
    conf.getAllConfs ++ matfastConf.getAllConfs
  }
}
Source File: QueryExecution.scala    From MatRel   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.matfast.execution

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{QueryExecution => SQLQueryExecution, SparkPlan}
import org.apache.spark.sql.matfast.MatfastSession

/**
 * Query execution for MatFast sessions, extending Spark's `QueryExecution`
 * (imported here as `SQLQueryExecution`).
 *
 * It declares a `matrixData` plan and overrides `optimizedPlan` and `sparkPlan`.
 *
 * @param matfastSession session the query runs in
 * @param matfastLogical logical plan to execute
 *
 * NOTE(review): the bodies of all three lazy vals and the closing braces are
 * absent from this extract, so what they compute cannot be documented from here.
 */
class QueryExecution(val matfastSession: MatfastSession, val matfastLogical: LogicalPlan)
  extends SQLQueryExecution(matfastSession, matfastLogical) {

  lazy val matrixData: LogicalPlan = {

  override lazy val optimizedPlan: LogicalPlan = {

  override lazy val sparkPlan: SparkPlan = {
Source File: SparkWrapper.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark

import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.types.{DataType, Metadata}

/**
 * Thin forwarding layer over a handful of Spark constructors and catalog calls.
 * Each method passes its arguments straight through to the Spark API named in
 * its body.
 *
 * NOTE(review): the body of `getVersion` and all closing braces are absent from
 * this extract.
 */
object SparkWrapper {
  def getVersion: String = {

  // Forwards to SubqueryAlias(identifier, child).
  def newSubqueryAlias(identifier: String, child: LogicalPlan): SubqueryAlias = {
    SubqueryAlias(identifier, child)

  // Forwards to Alias(child, name) with an empty second parameter list.
  def newAlias(child: Expression, name: String): Alias = {
    Alias(child, name)()

  // Forwards to AttributeReference(...) with an empty second parameter list.
  def newAttributeReference(
      name: String,
      dataType: DataType,
      nullable: Boolean,
      metadata: Metadata): AttributeReference = {
    AttributeReference(name, dataType, nullable, metadata)()

  // Forwards to obj.createTable(tableDefinition, ignoreIfExists).
  def callSessionCatalogCreateTable(
      obj: SessionCatalog,
      tableDefinition: CatalogTable,
      ignoreIfExists: Boolean): Unit = {
    obj.createTable(tableDefinition, ignoreIfExists)
Source File: SparkWrapper.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark

import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.types.{DataType, Metadata}

object SparkWrapper {
  def getVersion: String = {

  def newSubqueryAlias(identifier: String, child: LogicalPlan): SubqueryAlias = {
    SubqueryAlias(identifier, child)

  def newAlias(child: Expression, name: String): Alias = {
    Alias(child, name)()

  def newAttributeReference(
      name: String,
      dataType: DataType,
      nullable: Boolean,
      metadata: Metadata): AttributeReference = {
    AttributeReference(name, dataType, nullable, metadata)()

  def callSessionCatalogCreateTable(
      obj: SessionCatalog,
      tableDefinition: CatalogTable,
      ignoreIfExists: Boolean): Unit = {
    obj.createTable(tableDefinition, ignoreIfExists)