org.apache.spark.sql.catalyst.expressions.BoundReference Scala Examples
The following examples show how to use org.apache.spark.sql.catalyst.expressions.BoundReference. Each example is drawn from an open-source project; the source file, project, and license are noted above the code.
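Before the project examples, a minimal sketch of what BoundReference does may help: it refers to an input column purely by ordinal and data type (no name resolution involved), and evaluating it returns that slot of an InternalRow. The snippet below is illustrative only; the schema and values are made up.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Refer to slot 0 (a string) and slot 1 (an int) of the input row by position.
val nameRef = BoundReference(0, StringType, nullable = true)
val ageRef = BoundReference(1, IntegerType, nullable = false)

// Evaluating against an InternalRow returns the value at that ordinal.
val row = InternalRow(UTF8String.fromString("alice"), 42)
assert(nameRef.eval(row) == UTF8String.fromString("alice"))
assert(ageRef.eval(row) == 42)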
Example 1
Source File: GlobalSapSQLContext.scala From HANAVora-Extensions with Apache License 2.0
package org.apache.spark.sql

import java.io.File

import com.sap.spark.util.TestUtils
import com.sap.spark.{GlobalSparkContext, WithSQLContext}
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
import org.apache.spark.unsafe.types._
import org.apache.spark.sql.types._
import org.scalatest.Suite

import scala.io.Source

trait GlobalSapSQLContext extends GlobalSparkContext with WithSQLContext {
  self: Suite =>

  override implicit def sqlContext: SQLContext = GlobalSapSQLContext._sqlc

  override protected def setUpSQLContext(): Unit =
    GlobalSapSQLContext.init(sc)

  override protected def tearDownSQLContext(): Unit =
    GlobalSapSQLContext.reset()

  def getDataFrameFromSourceFile(sparkSchema: StructType, path: File): DataFrame = {
    // Bind each target column to the corresponding string slot, then cast it
    // to the column's declared data type.
    val conversions = sparkSchema.toSeq.zipWithIndex.map {
      case (field, index) =>
        Cast(BoundReference(index, StringType, nullable = true), field.dataType)
    }
    val data = Source.fromFile(path)
      .getLines()
      .map { line =>
        val stringRow = InternalRow.fromSeq(line.split(",", -1).map(UTF8String.fromString))
        Row.fromSeq(conversions.map(c => c.eval(stringRow)))
      }
    val rdd = sc.parallelize(data.toSeq, numberOfSparkWorkers)
    sqlContext.createDataFrame(rdd, sparkSchema)
  }
}

object GlobalSapSQLContext {
  private var _sqlc: SQLContext = _

  private def init(sc: SparkContext): Unit =
    if (_sqlc == null) {
      _sqlc = TestUtils.newSQLContext(sc)
    }

  private def reset(): Unit = {
    if (_sqlc != null) {
      _sqlc.catalog.unregisterAllTables()
    }
  }
}
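Example 1 combines BoundReference with Cast to parse CSV fields: each target column is read from its string slot and cast to the column's data type. A minimal sketch of that pattern in isolation (the values are made up):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Read slot 0 as a string, then cast it to the target column type.
val toInt = Cast(BoundReference(0, StringType, nullable = true), IntegerType)
assert(toInt.eval(InternalRow(UTF8String.fromString("123"))) == 123)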
Example 2
Source File: TypedExpressionEncoder.scala From frameless with Apache License 2.0
package frameless

import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If, Literal}
import org.apache.spark.sql.types.StructType

object TypedExpressionEncoder {

  def targetStructType[A](encoder: TypedEncoder[A]): StructType = {
    encoder.catalystRepr match {
      case x: StructType =>
        if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true)))
        else x
      case dt => new StructType().add("_1", dt, nullable = encoder.nullable)
    }
  }

  def apply[T: TypedEncoder]: ExpressionEncoder[T] = {
    val encoder = TypedEncoder[T]
    val schema = targetStructType(encoder)
    val in = BoundReference(0, encoder.jvmRepr, encoder.nullable)

    val (out, toRowExpressions) = encoder.toCatalyst(in) match {
      case If(_, _, x: CreateNamedStruct) =>
        val out = BoundReference(0, encoder.catalystRepr, encoder.nullable)
        (out, x.flatten)
      case other =>
        val out = GetColumnByOrdinal(0, encoder.catalystRepr)
        (out, CreateNamedStruct(Literal("_1") :: other :: Nil).flatten)
    }

    new ExpressionEncoder[T](
      schema = schema,
      flat = false,
      serializer = toRowExpressions,
      deserializer = encoder.fromCatalyst(out),
      clsTag = encoder.classTag
    )
  }
}
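As a hypothetical usage sketch (the Person case class, the local SparkSession, and the implicitly derived TypedEncoder are assumptions, not part of the source above), the resulting ExpressionEncoder can be passed wherever Spark expects an Encoder:

import frameless.TypedExpressionEncoder
import org.apache.spark.sql.SparkSession

case class Person(name: String, age: Int)

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
// frameless derives TypedEncoder[Person] generically; TypedExpressionEncoder
// wraps it into a Catalyst ExpressionEncoder that can back a Dataset[Person].
val ds = spark.createDataset(Seq(Person("alice", 42)))(TypedExpressionEncoder[Person])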
Example 3
Source File: GenerateUnsafeProjectionSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class GenerateUnsafeProjectionSuite extends SparkFunSuite {
  test("Test unsafe projection string access pattern") {
    val dataType = (new StructType).add("a", StringType)
    val exprs = BoundReference(0, dataType, nullable = true) :: Nil
    val projection = GenerateUnsafeProjection.generate(exprs)
    val result = projection.apply(InternalRow(AlwaysNull))
    assert(!result.isNullAt(0))
    assert(result.getStruct(0, 1).isNullAt(0))
  }
}

object AlwaysNull extends InternalRow {
  override def numFields: Int = 1
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this
  override def anyNull: Boolean = true
  override def isNullAt(ordinal: Int): Boolean = true
  override def update(i: Int, value: Any): Unit = notSupported
  override def getBoolean(ordinal: Int): Boolean = notSupported
  override def getByte(ordinal: Int): Byte = notSupported
  override def getShort(ordinal: Int): Short = notSupported
  override def getInt(ordinal: Int): Int = notSupported
  override def getLong(ordinal: Int): Long = notSupported
  override def getFloat(ordinal: Int): Float = notSupported
  override def getDouble(ordinal: Int): Double = notSupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
  override def getUTF8String(ordinal: Int): UTF8String = notSupported
  override def getBinary(ordinal: Int): Array[Byte] = notSupported
  override def getInterval(ordinal: Int): CalendarInterval = notSupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
  override def getArray(ordinal: Int): ArrayData = notSupported
  override def getMap(ordinal: Int): MapData = notSupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
  private def notSupported: Nothing = throw new UnsupportedOperationException
}
Example 4
Source File: ComplexDataSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class ComplexDataSuite extends SparkFunSuite {
  def utf8(str: String): UTF8String = UTF8String.fromString(str)

  test("inequality tests for MapData") {
    // test data
    val testMap1 = Map(utf8("key1") -> 1)
    val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val testMap3 = Map(utf8("key1") -> 1)
    val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericRow = genericRow.copy()
    assert(copiedGenericRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedGenericRow.getString(0) == "a")
  }

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val mutableRow = new SpecificInternalRow(Seq(StringType))
    mutableRow(0) = unsafeRow.getUTF8String(0)
    val copiedMutableRow = mutableRow.copy()
    assert(copiedMutableRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedMutableRow.getString(0) == "a")
  }

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericArray = genericArray.copy()
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied array data should not be changed externally.
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
  }

  test("copy on nested complex type") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))
    val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
    val copied = arrayOfRow.copy()
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied data should not be changed externally.
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
  }
}
Example 5
Source File: ExpandSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Literal}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType

class ExpandSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  private def testExpand(f: SparkPlan => SparkPlan): Unit = {
    val input = (1 to 1000).map(Tuple1.apply)
    val projections = Seq.tabulate(2) { i =>
      Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
    }
    val attributes = projections.head.map(_.toAttribute)
    checkAnswer(
      input.toDF(),
      plan => Expand(projections, attributes, f(plan)),
      input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))
    )
  }

  test("inheriting child row type") {
    val exprs = AttributeReference("a", IntegerType, false)() :: Nil
    val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
    assert(plan.outputsUnsafeRows, "Expand should inherit the created row type from its child.")
  }

  test("expanding UnsafeRows") {
    testExpand(ConvertToUnsafe)
  }

  test("expanding SafeRows") {
    testExpand(identity)
  }
}
Example 6
Source File: HiveAcidUtils.scala From spark-acid with Apache License 2.0
package org.apache.spark.sql.hive

import scala.collection.JavaConverters._

import com.qubole.spark.hiveacid.hive.HiveAcidMetadata
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTablePartition, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, InterpretedPredicate, PrettyAttribute}

object HiveAcidUtils {

  def prunePartitionsByFilter(
      hiveAcidMetadata: HiveAcidMetadata,
      inputPartitions: Seq[CatalogTablePartition],
      predicates: Option[Expression],
      defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
    if (predicates.isEmpty) {
      inputPartitions
    } else {
      val partitionSchema = hiveAcidMetadata.partitionSchema
      val partitionColumnNames = hiveAcidMetadata.partitionSchema.fieldNames.toSet

      val nonPartitionPruningPredicates = predicates.filterNot {
        _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
      }
      if (nonPartitionPruningPredicates.nonEmpty) {
        throw new AnalysisException("Expected only partition pruning predicates: " +
          nonPartitionPruningPredicates)
      }

      // Rewrite each partition-column attribute into a BoundReference against
      // the partition schema, so the predicate can be evaluated on partition rows.
      val boundPredicate = InterpretedPredicate.create(predicates.get.transform {
        case att: Attribute =>
          val index = partitionSchema.indexWhere(_.name == att.name)
          BoundReference(index, partitionSchema(index).dataType, nullable = true)
      })

      inputPartitions.filter { p =>
        boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
      }
    }
  }

  def convertToCatalogTablePartition(
      hp: com.qubole.shaded.hadoop.hive.ql.metadata.Partition): CatalogTablePartition = {
    val apiPartition = hp.getTPartition
    val properties: Map[String, String] =
      if (hp.getParameters != null) hp.getParameters.asScala.toMap else Map.empty
    CatalogTablePartition(
      spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty),
      storage = CatalogStorageFormat(
        locationUri = Option(CatalogUtils.stringToURI(apiPartition.getSd.getLocation)),
        inputFormat = Option(apiPartition.getSd.getInputFormat),
        outputFormat = Option(apiPartition.getSd.getOutputFormat),
        serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib),
        compressed = apiPartition.getSd.isCompressed,
        properties = Option(apiPartition.getSd.getSerdeInfo.getParameters)
          .map(_.asScala.toMap).orNull),
      createTime = apiPartition.getCreateTime.toLong * 1000,
      lastAccessTime = apiPartition.getLastAccessTime.toLong * 1000,
      parameters = properties,
      stats = None) // TODO: need to implement readHiveStats
  }
}
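The transform in prunePartitionsByFilter hand-rolls, for partition rows, what Spark's own BindReferences.bindReference helper does for general expressions: replace each named attribute with a BoundReference at its ordinal in the input. A minimal sketch of that helper (the attribute name is made up):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BindReferences, EqualTo, Expression, Literal}
import org.apache.spark.sql.types.IntegerType

val part = AttributeReference("part", IntegerType, nullable = true)()
// Substitutes the named attribute with a BoundReference at its ordinal
// in the given input attribute list.
val bound: Expression = BindReferences.bindReference[Expression](EqualTo(part, Literal(1)), Seq(part))
// bound is now EqualTo(BoundReference(0, IntegerType, true), Literal(1))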