org.apache.spark.sql.catalyst.expressions.NamedExpression Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.NamedExpression. You can go to the original project or source file by following the link above each example.
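For orientation, here is a minimal sketch of the NamedExpression surface most of the snippets below rely on: an AttributeReference is a named column reference, an Alias attaches a name and a fresh ExprId to an arbitrary expression, and toAttribute converts either back into an Attribute. This sketch is not taken from any of the listed projects; the object name NamedExpressionDemo is illustrative, and it only assumes Spark's catalyst classes are on the classpath.

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.types.IntegerType

object NamedExpressionDemo extends App {
  // An AttributeReference is itself a NamedExpression; the empty second argument list
  // assigns it a fresh ExprId via the default NamedExpression.newExprId.
  val col: NamedExpression = AttributeReference("a", IntegerType, nullable = false)()

  // An Alias names an arbitrary expression; newExprId allocates a globally unique id.
  val one: NamedExpression = Alias(Literal(1), "one")(exprId = NamedExpression.newExprId)

  // Every NamedExpression exposes a name and an exprId, and can be turned back into an Attribute.
  println(s"${col.name} -> ${col.exprId}")
  println(s"${one.name} -> ${one.exprId} -> ${one.toAttribute}")
}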
Example 1
Source File: TiAggregation.scala    From tispark   with Apache License 2.0
package org.apache.spark.sql

import com.pingcap.tispark.TiDBRelation
import com.pingcap.tispark.utils.ReflectionUtil
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources.LogicalRelation

object TiAggregation {
  type ReturnType =
    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)

  def unapply(plan: LogicalPlan): Option[ReturnType] =
    ReflectionUtil.callTiAggregationImplUnapply(plan)
}

object TiAggregationProjection {
  type ReturnType = (Seq[Expression], LogicalPlan, TiDBRelation, Seq[NamedExpression])

  def unapply(plan: LogicalPlan): Option[ReturnType] =
    plan match {
      // Only push down aggregates projection when all filters can be applied and
      // all projection expressions are column references
      case PhysicalOperation(
            projects,
            filters,
            rel @ LogicalRelation(source: TiDBRelation, _, _, _))
          if projects.forall(_.isInstanceOf[Attribute]) =>
        Some((filters, rel, source, projects))
      case _ => Option.empty[ReturnType]
    }
} 
Example 2
Source File: OptimizeHiveMetadataOnlyQuerySuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfter

import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton
    with BeforeAndAfter with SQLTestUtils {

  import spark.implicits._

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)")
    (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)"))
  }

  override protected def afterAll(): Unit = {
    try {
      sql("DROP TABLE IF EXISTS metadata_only")
    } finally {
      super.afterAll()
    }
  }

  test("SPARK-23877: validate metadata-only query pushes filters to metastore") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the number of matching partitions
      assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5)

      // verify that the partition predicate was pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5)
    }
  }

  test("SPARK-23877: filter on projected expression") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the matching partitions
      val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr,
        Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]),
          spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child)))
          .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType))))

      checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x"))

      // verify that the partition predicate was not pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11)
    }
  }
} 
Example 3
Source File: LogicalPlanSuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType


class LogicalPlanSuite extends SparkFunSuite {
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  test("transformUp runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan transformUp function

    assert(invocationCount === 1)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 1)
  }

  test("transformUp runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan transformUp function

    assert(invocationCount === 2)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 2)
  }

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true)

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output
    }

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
  }

  test("transformExpressions works with a Stream") {
    val id1 = NamedExpression.newExprId
    val id2 = NamedExpression.newExprId
    val plan = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(2), "b")(exprId = id2)),
      OneRowRelation())
    val result = plan.transformExpressions {
      case Literal(v: Int, IntegerType) if v != 1 =>
        Literal(v + 1, IntegerType)
    }
    val expected = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(3), "b")(exprId = id2)),
      OneRowRelation())
    assert(result.sameResult(expected))
  }
} 
Example 4
Source File: QueryPlanSuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)
    CurrentOrigin.reset()

    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)
    }

    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))
  }

} 
Example 5
Source File: OapAggUtils.scala    From OAP   with Apache License 2.0
package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, Partial}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.oap.OapAggregationFileScanExec

object OapAggUtils {
  private def createAggregate(
      requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
      groupingExpressions: Seq[NamedExpression] = Nil,
      aggregateExpressions: Seq[AggregateExpression] = Nil,
      aggregateAttributes: Seq[Attribute] = Nil,
      initialInputBufferOffset: Int = 0,
      resultExpressions: Seq[NamedExpression] = Nil,
      child: SparkPlan): SparkPlan = {
    if (requiredChildDistributionExpressions.isDefined) {
      // final aggregate, fall back to Spark HashAggregateExec.
      HashAggregateExec(
        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
        groupingExpressions = groupingExpressions,
        aggregateExpressions = aggregateExpressions,
        aggregateAttributes = aggregateAttributes,
        initialInputBufferOffset = initialInputBufferOffset,
        resultExpressions = resultExpressions,
        child = child)
    } else {
      // Apply partial aggregate optimizations.
      OapAggregateExec(
        requiredChildDistributionExpressions = None,
        groupingExpressions = groupingExpressions,
        aggregateExpressions = aggregateExpressions,
        aggregateAttributes = aggregateAttributes,
        initialInputBufferOffset = initialInputBufferOffset,
        resultExpressions = resultExpressions,
        child = child)
    }
  }

  def planAggregateWithoutDistinct(
      groupingExpressions: Seq[NamedExpression],
      aggregateExpressions: Seq[AggregateExpression],
      resultExpressions: Seq[NamedExpression],
      child: SparkPlan): Seq[SparkPlan] = {
    val useHash = HashAggregateExec.supportsAggregate(
      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))

    if (!child.isInstanceOf[OapAggregationFileScanExec] || !useHash) {
      // Child can not leverage oap optimization reading.
      Nil
    } else {
      // 1. Create an Aggregate Operator for partial aggregations.
      val groupingAttributes = groupingExpressions.map(_.toAttribute)
      val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
      val partialAggregateAttributes =
        partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
      val partialResultExpressions =
        groupingAttributes ++
          partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)

      val partialAggregate = createAggregate(
        requiredChildDistributionExpressions = None,
        groupingExpressions = groupingExpressions,
        aggregateExpressions = partialAggregateExpressions,
        aggregateAttributes = partialAggregateAttributes,
        initialInputBufferOffset = 0,
        resultExpressions = partialResultExpressions,
        child = child)

      // 2. Create an Aggregate Operator for final aggregations.
      val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
      // The attributes of the final aggregation buffer, which is presented as input to the result
      // projection:
      val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)

      val finalAggregate = createAggregate(
        requiredChildDistributionExpressions = Some(groupingAttributes),
        groupingExpressions = groupingAttributes,
        aggregateExpressions = finalAggregateExpressions,
        aggregateAttributes = finalAggregateAttributes,
        initialInputBufferOffset = groupingExpressions.length,
        resultExpressions = resultExpressions,
        child = partialAggregate)

      finalAggregate :: Nil
    }
  }
} 
Example 6
Source File: package.scala    From carbondata   with Apache License 2.0
package org.apache.carbondata.mv

import scala.language.implicitConversions

import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import org.apache.carbondata.mv.plans._
import org.apache.carbondata.mv.plans.modular.{JoinEdge, ModularPlan}
import org.apache.carbondata.mv.plans.modular.Flags._
import org.apache.carbondata.mv.plans.util._


package object dsl {

  object Plans {

    implicit class DslModularPlan(val modularPlan: ModularPlan) {
      def select(outputExprs: NamedExpression*)
        (inputExprs: Expression*)
        (predicateExprs: Expression*)
        (aliasMap: Map[Int, String])
        (joinEdges: JoinEdge*): ModularPlan = {
        modular
          .Select(
            outputExprs,
            inputExprs,
            predicateExprs,
            aliasMap,
            joinEdges,
            Seq(modularPlan),
            NoFlags,
            Seq.empty,
            Seq.empty)
      }

      def groupBy(outputExprs: NamedExpression*)
        (inputExprs: Expression*)
        (predicateExprs: Expression*): ModularPlan = {
        modular
          .GroupBy(outputExprs, inputExprs, predicateExprs, None, modularPlan, NoFlags, Seq.empty)
      }

      def harmonize: ModularPlan = modularPlan.harmonized
    }

    implicit class DslModularPlans(val modularPlans: Seq[ModularPlan]) {
      def select(outputExprs: NamedExpression*)
        (inputExprs: Expression*)
        (predicateList: Expression*)
        (aliasMap: Map[Int, String])
        (joinEdges: JoinEdge*): ModularPlan = {
        modular
          .Select(
            outputExprs,
            inputExprs,
            predicateList,
            aliasMap,
            joinEdges,
            modularPlans,
            NoFlags,
            Seq.empty,
            Seq.empty)
      }

      def union(): ModularPlan = modular.Union(modularPlans, NoFlags, Seq.empty)
    }

    implicit class DslLogical2Modular(val logicalPlan: LogicalPlan) {
      def resolveonly: LogicalPlan = analysis.SimpleAnalyzer.execute(logicalPlan)

      def modularize: ModularPlan = modular.SimpleModularizer.modularize(logicalPlan).next

      def optimize: LogicalPlan = BirdcageOptimizer.execute(logicalPlan)
    }

   }

} 
Example 7
Source File: ExpressionHelper.scala    From carbondata   with Apache License 2.0
package org.apache.carbondata.mv.plans.modular

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, Expression, NamedExpression}
import org.apache.spark.sql.types.{DataType, Metadata}

object ExpressionHelper {

  def createReference(
     name: String,
     dataType: DataType,
     nullable: Boolean,
     metadata: Metadata,
     exprId: ExprId,
     qualifier: Option[String],
     attrRef : NamedExpression = null): AttributeReference = {
    AttributeReference(name, dataType, nullable, metadata)(exprId, qualifier)
  }

  def createAlias(
       child: Expression,
       name: String,
       exprId: ExprId = NamedExpression.newExprId,
       qualifier: Option[String] = None,
       explicitMetadata: Option[Metadata] = None,
       namedExpr : Option[NamedExpression] = None ) : Alias = {
    Alias(child, name)(exprId, qualifier, explicitMetadata)
  }

  def getTheLastQualifier(reference: AttributeReference): String = {
    reference.qualifier.head
  }

} 
Example 8
Source File: ExpressionHelper.scala    From carbondata   with Apache License 2.0
package org.apache.carbondata.mv.plans.modular

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, Expression, NamedExpression}
import org.apache.spark.sql.types.{DataType, Metadata}

object ExpressionHelper {

  def createReference(
     name: String,
     dataType: DataType,
     nullable: Boolean,
     metadata: Metadata,
     exprId: ExprId,
     qualifier: Option[String],
     attrRef : NamedExpression = null): AttributeReference = {
    val qf = if (qualifier.nonEmpty) Seq(qualifier.get) else Seq.empty
    AttributeReference(name, dataType, nullable, metadata)(exprId, qf)
  }

  def createAlias(
      child: Expression,
      name: String,
      exprId: ExprId,
      qualifier: Option[String]) : Alias = {
    val qf = if (qualifier.nonEmpty) Seq(qualifier.get) else Seq.empty
    Alias(child, name)(exprId, qf, None)
  }

  def getTheLastQualifier(reference: AttributeReference): String = {
    reference.qualifier.reverse.head
  }

} 
Example 9
Source File: CarbonIUDRule.scala    From carbondata   with Apache License 2.0
package org.apache.spark.sql.optimizer

import org.apache.spark.sql.ProjectForUpdate
import org.apache.spark.sql.catalyst.expressions.{NamedExpression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.mutation.CarbonProjectForUpdateCommand

import org.apache.carbondata.core.constants.CarbonCommonConstants


class CarbonIUDRule extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan = {
      processPlan(plan)
  }

  private def processPlan(plan: LogicalPlan): LogicalPlan = {
    plan transform {
      case ProjectForUpdate(table, cols, Seq(updatePlan)) =>
        var isTransformed = false
        val newPlan = updatePlan transform {
          case Project(pList, child) if !isTransformed =>
            var (dest: Seq[NamedExpression], source: Seq[NamedExpression]) = pList
              .splitAt(pList.size - cols.size)
            // check complex column
            cols.foreach { col =>
              val complexExists = "\"name\":\"" + col + "\""
              if (dest.exists(m => m.dataType.json.contains(complexExists))) {
                throw new UnsupportedOperationException(
                  "Unsupported operation on Complex data type")
              }
            }
            // check updated columns exists in table
            val diff = cols.diff(dest.map(_.name.toLowerCase))
            if (diff.nonEmpty) {
              sys.error(s"Unknown column(s) ${ diff.mkString(",") } in table ${ table.tableName }")
            }
            // modify plan for updated column *in place*
            isTransformed = true
            source.foreach { col =>
              val colName = col.name.substring(0,
                col.name.lastIndexOf(CarbonCommonConstants.UPDATED_COL_EXTENSION))
              val updateIdx = dest.indexWhere(_.name.equalsIgnoreCase(colName))
              dest = dest.updated(updateIdx, col)
            }
            Project(dest, child)
        }
        CarbonProjectForUpdateCommand(
          newPlan, table.tableIdentifier.database, table.tableIdentifier.table, cols)
    }
  }
} 
Example 10
Source File: DruidOperatorSchema.scala    From spark-druid-olap   with Apache License 2.0
package org.apache.spark.sql.sources.druid

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.types.DataType
import org.sparklinedata.druid.{DruidOperatorAttribute, DruidQueryBuilder}


class DruidOperatorSchema(val dqb: DruidQueryBuilder) {

  // Druid output column name -> operator attribute (exprId, name, data type, transform).
  lazy val druidAttrMap: Map[String, DruidOperatorAttribute] = buildDruidOpAttr

  lazy val pushedDownExprToDruidAttr : Map[Expression, DruidOperatorAttribute] =
    buildPushDownDruidAttrsMap

  private def pushDownExpressionMap : Map[String, (Expression, DataType, DataType, String)] =
    dqb.outputAttributeMap.filter(t => t._2._1 != null)

  private def buildPushDownDruidAttrsMap : Map[Expression, DruidOperatorAttribute] =
    (pushDownExpressionMap map {
    case (nm, (e, oDT, dDT, tf)) => {
      (e -> druidAttrMap(nm))
    }
  })


  private def buildDruidOpAttr : Map[String, DruidOperatorAttribute] =
    (dqb.outputAttributeMap map {
      case (nm, (e, oDT, dDT, tf)) => {
        val druidEid = e match {
          case null => NamedExpression.newExprId
          case n: NamedExpression => n.exprId
          case _ => NamedExpression.newExprId
        }
        (nm -> DruidOperatorAttribute(druidEid, nm, dDT, tf))
      }
    }
      )

} 
Example 11
Source File: PostAggregate.scala    From spark-druid-olap   with Apache License 2.0
package org.apache.spark.sql.sources.druid

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.sparklinedata.druid._

class PostAggregate(val druidOpSchema : DruidOperatorSchema) {

  val dqb = druidOpSchema.dqb

  private def attrRef(dOpAttr : DruidOperatorAttribute) : AttributeReference =
    AttributeReference(dOpAttr.name, dOpAttr.dataType)(dOpAttr.exprId)

  lazy val groupExpressions = dqb.dimensions.map { d =>
    attrRef(druidOpSchema.druidAttrMap(d.outputName))

  }

  def namedGroupingExpressions = groupExpressions

  private def toSparkAgg(dAggSpec : AggregationSpec) : Option[AggregateFunction] = {
    val dOpAttr = druidOpSchema.druidAttrMap(dAggSpec.name)
    dAggSpec match
    {
      case FunctionAggregationSpec("count", nm, _) => Some(Sum(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longSum", nm, _) => Some(Sum(attrRef(dOpAttr)))
      case FunctionAggregationSpec("doubleSum", nm, _) => Some(Sum(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longMin", nm, _) => Some(Min(attrRef(dOpAttr)))
      case FunctionAggregationSpec("doubleMin", nm, _) => Some(Min(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longMax", nm, _) => Some(Max(attrRef(dOpAttr)))
      case FunctionAggregationSpec("doubleMax", nm, _) => Some(Max(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("MIN") =>
        Some(Min(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("MAX") =>
        Some(Max(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("SUM") =>
        Some(Sum(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("COUNT") =>
        Some(Sum(attrRef(dOpAttr)))
      case _ => None
    }
  }

  lazy val aggregatesO : Option[List[NamedExpression]] = Utils.sequence(
    dqb.aggregations.map { da =>
    val dOpAttr = druidOpSchema.druidAttrMap(da.name)
    toSparkAgg(da).map { aggFunc =>
      Alias(AggregateExpression(aggFunc, Complete, false), dOpAttr.name)(dOpAttr.exprId)
    }
  })

  def canBeExecutedInHistorical : Boolean = dqb.canPushToHistorical && aggregatesO.isDefined

  lazy val resultExpressions = groupExpressions ++ aggregatesO.get

  lazy val aggregateExpressions = resultExpressions.flatMap { expr =>
    expr.collect {
      case agg: AggregateExpression => agg
    }
  }.distinct

  lazy val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
    val aggregateFunction = agg.aggregateFunction
    val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
    (aggregateFunction, agg.isDistinct) -> attribute
  }.toMap

  lazy val rewrittenResultExpressions = resultExpressions.map { expr =>
    expr.transformDown {
      case aE@AggregateExpression(aggregateFunction, _, isDistinct, _) =>
        // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
        // so replace each aggregate expression by its corresponding attribute in the set:
        // aggregateFunctionToAttribute(aggregateFunction, isDistinct)
        aE.resultAttribute
      case expression => expression
    }.asInstanceOf[NamedExpression]
  }

  def aggOp(child : SparkPlan) : Seq[SparkPlan] = {
    org.apache.spark.sql.execution.aggregate.AggUtils.planAggregateWithoutPartial(
      namedGroupingExpressions,
        aggregateExpressions,
        rewrittenResultExpressions,
      child)
  }

} 
Example 12
Source File: FramelessInternals.scala    From frameless   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ObjectType
import scala.reflect.ClassTag

object FramelessInternals {
  def objectTypeFor[A](implicit classTag: ClassTag[A]): ObjectType = ObjectType(classTag.runtimeClass)

  def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = {
    ds.toDF.queryExecution.analyzed.resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver).getOrElse {
      throw new AnalysisException(
        s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames.mkString(", ")})""")
    }
  }

  def expr(column: Column): Expression = column.expr

  def column(column: Column): Expression = column.expr

  def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan

  def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution =
    ds.sparkSession.sessionState.executePlan(plan)

  def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan, rightPlan: LogicalPlan): LogicalPlan = {
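    // Recover each side's columns from the analyzed join and repackage them as ("_1", "_2") structs.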
    val joined = executePlan(ds, plan)
    val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
    val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)

    Project(List(
      Alias(CreateStruct(leftOutput), "_1")(),
      Alias(CreateStruct(rightOutput), "_2")()
    ), joined.analyzed)
  }

  def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] =
    new Dataset(sqlContext, plan, encoder)

  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
    Dataset.ofRows(sparkSession, logicalPlan)

  // because org.apache.spark.sql.types.UserDefinedType is private[spark]
  type UserDefinedType[A >: Null] =  org.apache.spark.sql.types.UserDefinedType[A]

  
  case class DisambiguateRight[T](tagged: Expression) extends Expression with NonSQLExpression {
    def eval(input: InternalRow): Any = tagged.eval(input)
    def nullable: Boolean = false
    def children: Seq[Expression] = tagged :: Nil
    def dataType: DataType = tagged.dataType
    protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ???
    override def genCode(ctx: CodegenContext): ExprCode = tagged.genCode(ctx)
  }
} 
Example 13
Source File: QueryPlanSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)
    CurrentOrigin.reset()

    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)
    }

    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))
  }

} 
Example 14
Source File: StarryTakeOrderedAndProjectExec.scala    From starry   with Apache License 2.0
package org.apache.spark.sql.execution.exchange

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.util.Utils


case class StarryTakeOrderedAndProjectExec(
                                            limit: Int,
                                            sortOrder: Seq[SortOrder],
                                            projectList: Seq[NamedExpression],
                                            child: SparkPlan) extends UnaryExecNode {

  override def output: Seq[Attribute] = {
    projectList.map(_.toAttribute)
  }

  override def executeCollect(): Array[InternalRow] = {
    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
    if (projectList != child.output) {
      val proj = UnsafeProjection.create(projectList, child.output)
      data.map(r => proj(r).copy())
    } else {
      data
    }
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
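    // Take the top `limit` rows within each partition first; the local winners are merged into a global top-K below.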
    val localTopK: RDD[InternalRow] = {
      child.execute().map(_.copy()).mapPartitions { iter =>
        org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord)
      }
    }
    localTopK.mapPartitions { iter =>
      val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
      if (projectList != child.output) {
        val proj = UnsafeProjection.create(projectList, child.output)
        topK.map(r => proj(r))
      } else {
        topK
      }
    }
  }

  override def outputOrdering: Seq[SortOrder] = sortOrder

  override def outputPartitioning: Partitioning = SinglePartition

  override def simpleString: String = {
    val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]")
    val outputString = Utils.truncatedString(output, "[", ",", "]")

    s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
  }
} 
Example 15
Source File: operators.scala    From BigDatalog   with Apache License 2.0
package edu.ucla.cs.wis.bigdatalog.spark.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, LeafNode, LogicalPlan, Statistics, UnaryNode}

case class Recursion(name: String,
                     isLinear: Boolean,
                     left: LogicalPlan,
                     right: LogicalPlan,
                     partitioning: Seq[Int]) extends BinaryNode {
  // left is exitRules plan
  // right is recursive rules plan
  override def output: Seq[Attribute] = right.output
}

case class MutualRecursion(name: String,
                             isLinear: Boolean,
                             left: LogicalPlan,
                             right: LogicalPlan,
                             partitioning: Seq[Int]) extends BinaryNode {

  override def output: Seq[Attribute] = right.output

  override def children: Seq[LogicalPlan] = {
    if (left == null)
      Seq(right)
    else
      Seq(left, right)
  }

  override def generateTreeString(depth: Int,
                         lastChildren: Seq[Boolean],
                         builder: StringBuilder): StringBuilder = {
    if (depth > 0) {
      lastChildren.init.foreach { isLast =>
        val prefixFragment = if (isLast) "   " else ":  "
        builder.append(prefixFragment)
      }

      val branch = if (lastChildren.last) "+- " else ":- "
      builder.append(branch)
    }

    builder.append(simpleString)
    builder.append("\n")

    if (children.nonEmpty) {
      val exitRule = children.init
      if (exitRule != null)
        exitRule.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
    }

    builder
  }
}

case class LinearRecursiveRelation(_name: String, output: Seq[Attribute], partitioning: Seq[Int]) extends LeafNode {
  override def statistics: Statistics = Statistics(Long.MaxValue)
  var name = _name
}

case class NonLinearRecursiveRelation(_name: String, output: Seq[Attribute], partitioning: Seq[Int]) extends LeafNode {
  override def statistics: Statistics = Statistics(Long.MaxValue)

  def name = "all_" + _name
}

case class MonotonicAggregate(groupingExpressions: Seq[Expression],
                              aggregateExpressions: Seq[NamedExpression],
                              child: LogicalPlan,
                              partitioning: Seq[Int]) extends UnaryNode {
  override lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved

  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}

case class AggregateRecursion(name: String,
                              isLinear: Boolean,
                              left: LogicalPlan,
                              right: LogicalPlan,
                              partitioning: Seq[Int]) extends BinaryNode {
  // left is exitRules plan
  // right is recursive rules plan
  override def output: Seq[Attribute] = right.output
}

case class AggregateRelation(_name: String, output: Seq[Attribute], partitioning: Seq[Int]) extends LeafNode {
  override def statistics: Statistics = Statistics(Long.MaxValue)
  var name = _name
}


case class CacheHint(child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
} 
Example 16
Source File: ProjectNodeSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
import org.apache.spark.sql.types.{IntegerType, StringType}


class ProjectNodeSuite extends LocalNodeTest {
  private val pieAttributes = Seq(
    AttributeReference("id", IntegerType)(),
    AttributeReference("age", IntegerType)(),
    AttributeReference("name", StringType)())

  private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = {
    val inputNode = new DummyNode(pieAttributes, inputData)
    val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2))
    val projectNode = new ProjectNode(conf, columns, inputNode)
    val expectedOutput = inputData.map { case (id, age, name) => (id, name) }
    val actualOutput = projectNode.collect().map { case row =>
      (row.getInt(0), row.getString(1))
    }
    assert(actualOutput === expectedOutput)
  }

  test("empty") {
    testProject()
  }

  test("basic") {
    testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray)
  }

}