org.apache.spark.sql.types.NumericType Scala Examples

The following examples show how to use org.apache.spark.sql.types.NumericType. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: InteractionOp.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.bundle.ops.feature

import ml.bundle.DataShape
import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.{OpModel, OpNode}
import ml.combust.mleap.core.annotation.SparkCode
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.bundle._
import org.apache.spark.ml.feature.Interaction
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.mleap.TypeConverters._
import ml.combust.mleap.runtime.types.BundleTypeConverters._
import org.apache.spark.sql.types.{BooleanType, NumericType}


/** Bundle op that stores and loads Spark's [[Interaction]] transformer. */
class InteractionOp extends SimpleSparkOp[Interaction] {
  override val Model: OpModel[SparkBundleContext, Interaction] = new OpModel[SparkBundleContext, Interaction] {
    override val klazz: Class[Interaction] = classOf[Interaction]

    override def opName: String = Bundle.BuiltinOps.feature.interaction

    /**
     * Writes the interaction spec into the bundle model.
     *
     * Stores `num_inputs`, `input_shapes` and one `num_features<i>` int list per
     * input column. Requires a sample dataset on the bundle context, because the
     * feature counts are read from that dataset's schema.
     */
    override def store(model: Model, obj: Interaction)
                      (implicit context: BundleContext[SparkBundleContext]): Model = {
      assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))

      val sample = context.context.dataset.get
      val inputNames = obj.getInputCols
      val featureCounts = buildSpec(inputNames, sample)
      // The type ascription selects the conversion to the bundle DataShape.
      val shapes = inputNames.map(name => sparkToMleapDataShape(sample.schema(name), sample): DataShape)

      val withShapes = model
        .withValue("num_inputs", Value.int(featureCounts.length))
        .withValue("input_shapes", Value.dataShapeList(shapes))

      featureCounts.indices.foldLeft(withShapes) { (acc, idx) =>
        acc.withValue(s"num_features$idx", Value.intList(featureCounts(idx)))
      }
    }

    /** Returns a fresh transformer; nothing is read back from the bundle model. */
    override def load(model: Model)
                     (implicit context: BundleContext[SparkBundleContext]): Interaction = {
      // No need to do anything here, everything is handled through Spark meta data
      new Interaction()
    }

    /**
     * Computes, for every input column, the number of features contributed by
     * each of its attributes (one entry for scalar columns, one per attribute
     * for vector columns).
     */
    @SparkCode(uri = "https://github.com/apache/spark/blob/branch-2.1/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala")
    private def buildSpec(inputCols: Array[String], dataset: DataFrame): Array[Array[Int]] = {
      // Nominal attributes contribute one feature per value (at least one);
      // everything else is treated as a single numeric feature.
      def getNumFeatures(attr: Attribute): Int = attr match {
        case nominal: NominalAttribute =>
          val numValues = nominal.getNumValues match {
            case Some(n) => n
            case None =>
              throw new IllegalArgumentException("Nominal features must have attr numValues defined.")
          }
          math.max(1, numValues)
        case _ =>
          1
      }

      // Resolve every column first, then derive the counts per field.
      val fields = inputCols.map(name => dataset.schema(name))
      fields.map { field =>
        field.dataType match {
          case _: NumericType | BooleanType =>
            Array(getNumFeatures(Attribute.fromStructField(field)))
          case _: VectorUDT =>
            AttributeGroup.fromStructField(field).attributes match {
              case Some(attrs) => attrs.map(getNumFeatures)
              case None =>
                throw new IllegalArgumentException("Vector attributes must be defined for interaction.")
            }
        }
      }
    }
  }

  /** Rebuilds the transformer with the given uid; `shape` and `model` are not consulted. */
  override def sparkLoad(uid: String, shape: NodeShape, model: Interaction): Interaction =
    new Interaction(uid = uid)

  override def sparkInputs(obj: Interaction): Seq[ParamSpec] =
    Seq("input" -> obj.inputCols)

  override def sparkOutputs(obj: Interaction): Seq[SimpleParamSpec] =
    Seq("output" -> obj.outputCol)
} 
Example 2
Source File: MathUnary.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.feature

import ml.combust.mleap.core.feature.{MathUnaryModel, UnaryOperation}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
import org.apache.spark.sql.functions.udf


    // Fully qualified class name handed to the metadata loader.
    private val className = classOf[MathUnary].getName

    /**
     * Restores a [[MathUnary]] transformer from `path`.
     *
     * Metadata is loaded first; the operation name is then read from the first
     * row of the parquet data under `path/data`, and the params recorded in the
     * metadata are applied to the rebuilt transformer.
     */
    override def load(path: String): MathUnary = {
      val meta = DefaultParamsReader.loadMetadata(path, sc, className)

      val operationName = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("operation")
        .head()
        .getAs[String](0)

      val restored = new MathUnary(meta.uid, MathUnaryModel(UnaryOperation.forName(operationName)))
      meta.getAndSetParams(restored)
      restored
    }
  }

} 
Example 3
Source File: ProtoConversions.scala    From tensorframes   with Apache License 2.0 5 votes vote down vote up
package org.tensorframes.dsl

import org.tensorflow.framework.{AttrValue, DataType, NodeDef}
import org.tensorframes.impl.SupportedOperations

import org.apache.spark.sql.types.NumericType


/** Helpers for moving between Spark SQL numeric types and TensorFlow proto values. */
private[tensorframes] object ProtoConversions {

  /**
   * Reads the data type of a node from its `T` attribute, falling back to
   * `dtype` when `T` is absent. Throws when neither attribute is present.
   */
  def getDType(nodeDef: NodeDef): DataType = {
    val attrs = nodeDef.getAttr
    val typeAttr = Option(attrs.get("T")).orElse(Option(attrs.get("dtype")))
    typeAttr match {
      case Some(attr) => attr.getType
      case None => throw new Exception(s"Neither 'T' no 'dtype' was found in $nodeDef")
    }
  }

  /** Looks up the TensorFlow type registered for the given SQL numeric type. */
  def getDType(sqlType: NumericType): DataType =
    SupportedOperations.opsFor(sqlType).tfType

  /** Builds an attribute value carrying the TensorFlow type of `sqlType`. */
  def sqlTypeToAttrValue(sqlType: NumericType): AttrValue =
    dataTypeToAttrValue(getDType(sqlType))

  /** Builds an attribute value carrying `dataType`. */
  def dataTypeToAttrValue(dataType: DataType): AttrValue = {
    val builder = AttrValue.newBuilder()
    builder.setType(dataType)
    builder.build()
  }

} 
Example 4
Source File: GenerateOrdering.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{BinaryType, StringType, NumericType}


/**
 * Builds an `Ordering[Row]` for a sequence of [[SortOrder]] expressions by
 * assembling a Scala tree with quasiquotes and evaluating it through `toolBox`.
 */
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
  import scala.reflect.runtime.{universe => ru}
  import scala.reflect.runtime.universe._

  protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
    in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
    in.map(BindReferences.bindReference(_, inputSchema))

  /**
   * Generates the ordering for `ordering`.
   *
   * Each sort key yields one comparison fragment. A fragment returns from the
   * generated `compare` only when its key decides the result; when both values
   * are null or compare equal, control falls through to the next key, and
   * `compare` returns 0 after the last one.
   *
   * Only `BinaryType`, `NumericType` and `StringType` keys are handled; any
   * other type fails with a `MatchError` while generating.
   */
  protected def create(ordering: Seq[SortOrder]): Ordering[Row] = {
    val a = newTermName("a")
    val b = newTermName("b")
    val comparisons = ordering.map { order =>
      val evalA = expressionEvaluator(order.child)
      val evalB = expressionEvaluator(order.child)

      // A descending order is the ascending comparison with the operands swapped.
      val (lhs, rhs) =
        if (order.direction == Ascending) (evalA.primitiveTerm, evalB.primitiveTerm)
        else (evalB.primitiveTerm, evalA.primitiveTerm)

      val compare = order.child.dataType match {
        case BinaryType =>
          // Element-wise comparison, then shorter-array-first. Previously both
          // `x` and `y` were bound to the same operand, so every pair compared
          // equal; the fragment also returned on ties, hiding later sort keys.
          q"""
          val x = $lhs
          val y = $rhs
          var i = 0
          while (i < x.length && i < y.length) {
            val res = x(i).compareTo(y(i))
            if (res != 0) return res
            i = i+1
          }
          val lengthDiff = x.length - y.length
          if (lengthDiff != 0) return lengthDiff
          """
        case _: NumericType =>
          // NOTE(review): `comp.toInt` can yield 0 for a non-zero fractional
          // difference and can lose the sign on wide integral differences —
          // left unchanged here, confirm against the supported numeric types.
          q"""
          val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
          if(comp != 0) {
            return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
          }
          """
        case StringType =>
          // Return only when the strings differ so that ties reach the next key.
          q"""
          val comp = $lhs.compare($rhs)
          if (comp != 0) return comp
          """
      }

      // `i` is the generated method's current-row variable: the evaluation code
      // for each side runs after `i` has been pointed at that side's row.
      q"""
        i = $a
        ..${evalA.code}
        i = $b
        ..${evalB.code}
        if (${evalA.nullTerm} && ${evalB.nullTerm}) {
          // Nothing
        } else if (${evalA.nullTerm}) {
          return ${if (order.direction == Ascending) q"-1" else q"1"}
        } else if (${evalB.nullTerm}) {
          return ${if (order.direction == Ascending) q"1" else q"-1"}
        } else {
          $compare
        }
      """
    }

    // Reify a skeleton class to obtain its name, parent type and body (which
    // references `ordering`), then splice `compare` into it below.
    val q"class $orderingName extends $orderingType { ..$body }" = reify {
      class SpecificOrdering extends Ordering[Row] {
        val o = ordering
      }
    }.tree.children.head

    val code = q"""
      class $orderingName extends $orderingType {
        ..$body
        def compare(a: $rowType, b: $rowType): Int = {
          var i: $rowType = null // Holds current row being evaluated.
          ..$comparisons
          return 0
        }
      }
      new $orderingName()
      """
    logDebug(s"Generated Ordering: $code")
    toolBox.eval(code).asInstanceOf[Ordering[Row]]
  }
} 
Example 5
Source File: ArrangePostprocessor.scala    From DataQuality   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package it.agilelab.bigdata.DataQuality.postprocessors

import com.typesafe.config.Config
import it.agilelab.bigdata.DataQuality.checks.CheckResult
import it.agilelab.bigdata.DataQuality.metrics.MetricResult
import it.agilelab.bigdata.DataQuality.sources.HdfsFile
import it.agilelab.bigdata.DataQuality.targets.HdfsTargetConfig
import it.agilelab.bigdata.DataQuality.utils
import it.agilelab.bigdata.DataQuality.utils.DQSettings
import it.agilelab.bigdata.DataQuality.utils.io.{HdfsReader, HdfsWriter}
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, NumericType}
import org.apache.spark.sql.{Column, DataFrame, SQLContext}

import scala.collection.JavaConversions._

/**
 * Postprocessor that selects, casts and formats the columns of a virtual
 * source in the order given by the `columnOrder` config list, then saves the
 * result to the target described by the `saveTo` config section.
 */
final class ArrangePostprocessor(config: Config, settings: DQSettings)
    extends BasicPostprocessor(config, settings) {

  /**
   * One entry of `columnOrder`.
   *
   * @param name      column name in the source data frame
   * @param tipo      optional cast type: DOUBLE, INT or LONG (case-insensitive);
   *                  any other value means no cast
   * @param format    optional pattern passed to `format_string`
   * @param precision optional decimal places passed to `format_number`
   */
  private case class ColumnSelector(name: String, tipo: Option[String] = None, format: Option[String] = None, precision: Option[Integer] = None) {
    /** Builds the select expression for this entry against the implicit data frame. */
    def toColumn()(implicit df: DataFrame): Column = {

      val dataType: Option[NumericType with Product with Serializable] =
        tipo.getOrElse("").toUpperCase match {
          case "DOUBLE" => Some(DoubleType)
          case "INT"    => Some(IntegerType)
          case "LONG"   => Some(LongType)
          case _        => None
        }

      import org.apache.spark.sql.functions.format_number
      import org.apache.spark.sql.functions.format_string

      // Formatted columns are aliased back to the original name; any
      // combination not listed below selects the column untouched.
      (dataType, precision, format) match {
        case (Some(dt), None, None) => df(name).cast(dt)
        case (Some(dt), None, Some(f)) => format_string(f, df(name).cast(dt)).alias(name)
        case (Some(dt), Some(p), None) => format_number(df(name).cast(dt), p).alias(name)
        case (None, Some(p), None) => format_number(df(name), p).alias(name)
        case (None, None, Some(f)) => format_string(f, df(name)).alias(name)
        case _ => df(name)
      }
    }
  }

  private val vs = config.getString("source")
  private val target: HdfsTargetConfig = {
    val conf = config.getConfig("saveTo")
    utils.parseTargetConfig(conf)(settings).get
  }

  // Accepted entry shapes: "name", {name: "TYPE"}, {name: {"TYPE": precision}}
  // and {name: {"TYPE": "format"}}. Anything else is rejected with a message
  // naming the offending entry instead of an anonymous MatchError.
  private val columns: Seq[ColumnSelector] =
    config.getAnyRefList("columnOrder").map {
      case x: String => ColumnSelector(x)
      case x: java.util.HashMap[_, String] => {
        if (x.isEmpty) {
          throw new IllegalArgumentException("Empty object found in 'columnOrder'")
        }
        val (name, v) = x.head.asInstanceOf[String Tuple2 _]

        v match {
          case v: String =>
            ColumnSelector(name, Option(v))
          case v: java.util.HashMap[String, _] => {
            if (v.isEmpty) {
              throw new IllegalArgumentException(
                s"Empty type specification for column '$name' in 'columnOrder'")
            }
            val k = v.head._1
            val f = v.head._2

            f match {
              case f: Integer =>
                ColumnSelector(name, Option(k), None, Option(f))
              case f: String =>
                ColumnSelector(name, Option(k), Option(f))
              case other =>
                throw new IllegalArgumentException(
                  s"Unsupported precision/format for column '$name' in 'columnOrder': $other")
            }
          }
          case other =>
            throw new IllegalArgumentException(
              s"Unsupported specification for column '$name' in 'columnOrder': $other")
        }
      }
      case other =>
        throw new IllegalArgumentException(s"Unsupported entry in 'columnOrder': $other")
    }

  /**
   * Loads the configured virtual source, applies the column selection and
   * saves the result to the configured target.
   *
   * @throws NoSuchElementException if the virtual source is not in `vsRef`
   *                                or the reader returns nothing for it
   */
  override def process(vsRef: Set[HdfsFile],
                       metRes: Seq[MetricResult],
                       chkRes: Seq[CheckResult])(
      implicit fs: FileSystem,
      sqlContext: SQLContext,
      settings: DQSettings): HdfsFile = {

    val reqVS: HdfsFile = vsRef.find(vr => vr.id == vs).getOrElse(
      throw new NoSuchElementException(s"Virtual source '$vs' was not found among the provided sources"))
    implicit val df: DataFrame = HdfsReader.load(reqVS, settings.ref_date).headOption.getOrElse(
      throw new NoSuchElementException(s"No data could be loaded for virtual source '$vs'"))

    val arrangeDF = df.select(columns.map(_.toColumn): _*)

    HdfsWriter.saveVirtualSource(arrangeDF, target, settings.refDateString)(
      fs,
      sqlContext.sparkContext)

    new HdfsFile(target)
  }
} 
Example 6
Source File: MeanSubstitute.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution


/**
 * Expression that replaces missing entries of a numeric array with the mean of
 * its non-missing entries. Entries count as missing when they are NaN, null or
 * equal to `missingValue`.
 */
case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  /** Uses the literal -1 as the missing value. */
  def this(array: Expression) = this(array, Literal(-1))

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  /** Predicate that holds when the element is NaN, null or equal to `missingValue`. */
  def isMissing(arrayElement: Expression): Predicate = {
    val nanOrNull = IsNaN(arrayElement) || IsNull(arrayElement)
    nanOrNull || arrayElement === missingValue
  }

  /** Builds a struct with the fields `sum` and `count`. */
  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression =
    namedStruct(
      Literal(UTF8String.fromString("sum"), StringType),
      sumValue,
      Literal(UTF8String.fromString("count"), StringType),
      countValue)

  /** Folds one element into the running state; missing elements leave the state as is. */
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    val nextSum = stateStruct.getField("sum") + arrayElement
    val nextCount = stateStruct.getField("count") + 1
    If(isMissing(arrayElement), stateStruct, createNamedStruct(nextSum, nextCount))
  }

  /** Turns the final state into the mean, or into `missingValue` when nothing was counted. */
  def calculateMean(stateStruct: Expression): Expression = {
    val hasValues = stateStruct.getField("count") > 0
    val mean = stateStruct.getField("sum") / stateStruct.getField("count")
    If(hasValues, mean, missingValue)
  }

  /** Mean of the non-missing elements, expressed as an aggregation over the array. */
  lazy val arrayMean: Expression = {
    val emptyState = createNamedStruct(Literal(0d), Literal(0L))
    array.aggregate(emptyState, updateSumAndCountConditionally, calculateMean)
  }

  /** Selects the array mean for missing elements and the element itself otherwise. */
  def substituteWithMean(arrayElement: Expression): Expression =
    If(isMissing(arrayElement), arrayMean, arrayElement)

  /** Validates the argument types and expands into the element-wise substitution. */
  override def rewrite: Expression = {
    // The element type is only inspected once the array check has passed.
    val isNumericArray =
      array.dataType.isInstanceOf[ArrayType] && arrayElementType.isInstanceOf[NumericType]
    if (!isNumericArray) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }

    missingValue.dataType match {
      case _: NumericType => // accepted
      case _ =>
        throw SQLUtils.newAnalysisException(
          s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    array.arrayTransform(element => substituteWithMean(element))
  }
} 
Example 7
Source File: HashMapIndexedRelation.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.index

import org.apache.spark.sql.simba.partitioner.HashPartition
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.NumericType
import org.apache.spark.storage.StorageLevel


/**
 * Indexed relation whose partitions each carry a `HashMapIndex`, keyed on a
 * single numeric column. The index is built on construction unless an
 * already-built `_indexedRDD` is supplied.
 */
private[simba] case class HashMapIndexedRelation(output: Seq[Attribute], child: SparkPlan,
    table_name: Option[String], column_keys: List[Attribute], index_name: String)(var _indexedRDD: IndexedRDD = null)
  extends IndexedRelation with MultiInstanceRelation {

  require(column_keys.length == 1, "hash map index supports exactly one key column")
  require(column_keys.head.dataType.isInstanceOf[NumericType],
    "hash map index requires a numeric key column")
  if (_indexedRDD == null) {
    buildIndex()
  }

  /**
   * Evaluates the key of every child row, hash-partitions the (key, row) pairs,
   * builds one index per partition and stores the persisted result in
   * `_indexedRDD`.
   */
  private[simba] def buildIndex(): Unit = {
    val numShufflePartitions = simbaSession.sessionState.simbaConf.indexPartitions

    // Bind the key attribute once instead of re-binding it for every row.
    val boundKey = BindReferences.bindReference(column_keys.head, child.output)
    val dataRDD = child.execute().map(row => (boundKey.eval(row), row))

    val partitionedRDD = HashPartition(dataRDD, numShufflePartitions)
    val indexed = partitionedRDD.mapPartitions(iter => {
      val data = iter.toArray
      val index = HashMapIndex(data)
      Array(IPartition(data.map(_._2), index)).iterator
    }).persist(StorageLevel.MEMORY_AND_DISK_SER)

    indexed.setName(table_name.map(n => s"$n $index_name").getOrElse(child.toString))
    _indexedRDD = indexed
  }

  /** Copy with fresh output attribute instances that reuses the built index. */
  override def newInstance(): IndexedRelation = {
    HashMapIndexedRelation(output.map(_.newInstance()), child, table_name,
      column_keys, index_name)(_indexedRDD).asInstanceOf[this.type]
  }

  /** Copy with the given output attributes that reuses the built index. */
  override def withOutput(new_output: Seq[Attribute]): IndexedRelation = {
    HashMapIndexedRelation(new_output, child, table_name, column_keys, index_name)(_indexedRDD)
  }
} 
Example 8
Source File: TreapIndexedRelation.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.index

import org.apache.spark.sql.simba.partitioner.RangePartition
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.NumericType
import org.apache.spark.storage.StorageLevel


/**
 * Indexed relation whose partitions each carry a `Treap`, keyed on a single
 * numeric column. The index is built on construction unless an already-built
 * `_indexedRDD` is supplied; `range_bounds` holds the bounds returned by the
 * range partitioner.
 */
private[simba] case class TreapIndexedRelation(output: Seq[Attribute], child: SparkPlan,
    table_name: Option[String], column_keys: List[Attribute], index_name: String)
    (var _indexedRDD: IndexedRDD = null, var range_bounds: Array[Double] = null)
  extends IndexedRelation with MultiInstanceRelation {

  require(column_keys.length == 1, "treap index supports exactly one key column")
  require(column_keys.head.dataType.isInstanceOf[NumericType],
    "treap index requires a numeric key column")
  val numShufflePartitions = simbaSession.sessionState.simbaConf.indexPartitions

  if (_indexedRDD == null) {
    buildIndex()
  }

  /**
   * Evaluates the key of every child row as a Double, range-partitions the
   * (key, row) pairs, builds one treap per partition and stores the persisted
   * result in `_indexedRDD` and the partition bounds in `range_bounds`.
   */
  private[simba] def buildIndex(): Unit = {
    // Bind the key attribute once instead of re-binding it for every row.
    val boundKey = BindReferences.bindReference(column_keys.head, child.output)
    val dataRDD = child.execute().map { row =>
      // Any numeric key is accepted by the `require` above, so widen to Double
      // explicitly; a plain cast only works when the value already is a Double.
      val eval_key: Double = boundKey.eval(row) match {
        case n: java.lang.Number => n.doubleValue()
        case d: org.apache.spark.sql.types.Decimal => d.toDouble
        case other => other.asInstanceOf[Double] // keeps the previous behaviour for null
      }
      (eval_key, row)
    }

    val (partitionedRDD, tmp_bounds) = RangePartition.rowPartition(dataRDD, numShufflePartitions)
    range_bounds = tmp_bounds
    val indexed = partitionedRDD.mapPartitions(iter => {
      val data = iter.toArray
      val index = Treap(data)
      Array(IPartition(data.map(_._2), index)).iterator
    }).persist(StorageLevel.MEMORY_AND_DISK_SER)

    indexed.setName(table_name.map(n => s"$n $index_name").getOrElse(child.toString))
    _indexedRDD = indexed
  }

  /** Copy with fresh output attribute instances that reuses the built index and its bounds. */
  override def newInstance(): IndexedRelation = {
    // Forward range_bounds as withOutput does; otherwise the copy would keep
    // the index but end up with null bounds.
    TreapIndexedRelation(output.map(_.newInstance()), child, table_name,
      column_keys, index_name)(_indexedRDD, range_bounds)
      .asInstanceOf[this.type]
  }

  /** Copy with the given output attributes that reuses the built index and its bounds. */
  override def withOutput(new_output: Seq[Attribute]): IndexedRelation = {
    TreapIndexedRelation(new_output, child, table_name,
      column_keys, index_name)(_indexedRDD, range_bounds)
  }
} 
Example 9
Source File: TreeMapIndexedRelation.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.index

import org.apache.spark.sql.simba.partitioner.RangePartition
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.NumericType
import org.apache.spark.storage.StorageLevel


/**
 * Indexed relation whose partitions each carry a `TreeMapIndex`, keyed on a
 * single numeric column. The index is built on construction unless an
 * already-built `_indexedRDD` is supplied; `range_bounds` holds the bounds
 * returned by the range partitioner.
 */
private[simba] case class TreeMapIndexedRelation(output: Seq[Attribute], child: SparkPlan,
    table_name: Option[String], column_keys: List[Attribute], index_name: String)
    (var _indexedRDD: IndexedRDD = null, var range_bounds: Array[Double] = null)
  extends IndexedRelation with MultiInstanceRelation {

  require(column_keys.length == 1, "tree map index supports exactly one key column")
  require(column_keys.head.dataType.isInstanceOf[NumericType],
    "tree map index requires a numeric key column")

  if (_indexedRDD == null) {
    buildIndex()
  }

  /**
   * Evaluates the key of every child row as a Double, range-partitions the
   * (key, row) pairs, builds one tree map index per partition and stores the
   * persisted result in `_indexedRDD` and the partition bounds in `range_bounds`.
   */
  private[simba] def buildIndex(): Unit = {
    val numShufflePartitions = simbaSession.sessionState.simbaConf.indexPartitions

    // Bind the key attribute once instead of re-binding it for every row.
    val boundKey = BindReferences.bindReference(column_keys.head, child.output)
    val dataRDD = child.execute().map { row =>
      // Any numeric key is accepted by the `require` above, so widen to Double
      // explicitly; a plain cast only works when the value already is a Double.
      val eval_key: Double = boundKey.eval(row) match {
        case n: java.lang.Number => n.doubleValue()
        case d: org.apache.spark.sql.types.Decimal => d.toDouble
        case other => other.asInstanceOf[Double] // keeps the previous behaviour for null
      }
      (eval_key, row)
    }

    val (partitionedRDD, tmp_bounds) = RangePartition.rowPartition(dataRDD, numShufflePartitions)
    range_bounds = tmp_bounds
    val indexed = partitionedRDD.mapPartitions(iter => {
      val data = iter.toArray
      val index = TreeMapIndex(data)
      Array(IPartition(data.map(_._2), index)).iterator
    }).persist(StorageLevel.MEMORY_AND_DISK_SER)

    indexed.setName(table_name.map(n => s"$n $index_name").getOrElse(child.toString))
    _indexedRDD = indexed
  }

  /** Copy with fresh output attribute instances that reuses the built index and its bounds. */
  override def newInstance(): IndexedRelation = {
    // Forward range_bounds as withOutput does; otherwise the copy would keep
    // the index but end up with null bounds.
    TreeMapIndexedRelation(output.map(_.newInstance()), child, table_name,
      column_keys, index_name)(_indexedRDD, range_bounds)
      .asInstanceOf[this.type]
  }

  /** Copy with the given output attributes that reuses the built index and its bounds. */
  override def withOutput(new_output: Seq[Attribute]): IndexedRelation = {
    TreeMapIndexedRelation(new_output, child, table_name,
      column_keys, index_name)(_indexedRDD, range_bounds)
  }
} 
Example 10
Source File: RTreeIndexedRelation.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.index

import org.apache.spark.sql.simba.ShapeType
import org.apache.spark.sql.simba.partitioner.STRPartition
import org.apache.spark.sql.simba.util.ShapeUtils
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.NumericType
import org.apache.spark.storage.StorageLevel


/**
 * Indexed relation whose partitions each carry a local `RTree`, plus a global
 * `RTree` built over the partition bounds. Keys are either several numeric
 * columns or a single `ShapeType` column. The index is built on construction
 * unless an already-built `_indexedRDD` is supplied.
 */
private[simba] case class RTreeIndexedRelation(output: Seq[Attribute], child: SparkPlan, table_name: Option[String],
    column_keys: List[Attribute], index_name: String)(var _indexedRDD: IndexedRDD = null, var global_rtree: RTree = null)
  extends IndexedRelation with MultiInstanceRelation {

  // Set by checkKeys when the single key column is a shape column.
  var isPoint = false

  /**
   * Validates the key columns: more than one key requires all of them to be
   * numeric; a single key must be a `ShapeType` column (a one-dimensional
   * R-tree is not supported). An empty key list is rejected.
   */
  private def checkKeys: Boolean = {
    if (column_keys.length > 1) {
      column_keys.forall(_.dataType.isInstanceOf[NumericType])
    } else {
      column_keys.headOption.exists { key =>
        key.dataType match {
          case _: ShapeType =>
            isPoint = true
            true
          case _ => false
        }
      }
    }
  }
  require(checkKeys, "R-tree index requires several numeric key columns or a single shape column")

  // Number of coordinates of the point extracted from the first row of the child plan.
  val dimension = ShapeUtils.getPointFromRow(child.execute().first(), column_keys, child, isPoint).coord.length

  if (_indexedRDD == null) {
    buildIndex()
  }

  /**
   * Extracts a point per child row, partitions the (point, row) pairs with
   * `STRPartition`, builds one local R-tree per non-empty partition, then
   * builds the global R-tree from the partition bounds and sizes. Results are
   * stored in `_indexedRDD` and `global_rtree`.
   */
  private[simba] def buildIndex(): Unit = {
    val simbaConf = simbaSession.sessionState.simbaConf
    val numShufflePartitions = simbaConf.indexPartitions
    val maxEntriesPerNode = simbaConf.maxEntriesPerNode
    val sampleRate = simbaConf.sampleRate
    val transferThreshold = simbaConf.transferThreshold
    val dataRDD = child.execute().map(row => {
      (ShapeUtils.getPointFromRow(row, column_keys, child, isPoint), row)
    })

    val (partitionedRDD, mbr_bounds) =
      STRPartition(dataRDD, dimension, numShufflePartitions, sampleRate, transferThreshold, maxEntriesPerNode)

    val indexed = partitionedRDD.mapPartitions { iter =>
      val data = iter.toArray
      // Empty partitions carry no local tree.
      val index: RTree =
        if (data.nonEmpty) RTree(data.map(_._1).zipWithIndex, maxEntriesPerNode) else null
      Array(IPartition(data.map(_._2), index)).iterator
    }.persist(StorageLevel.MEMORY_AND_DISK_SER)

    val partitionSize = indexed.mapPartitions(iter => iter.map(_.data.length)).collect()

    global_rtree = RTree(mbr_bounds.zip(partitionSize)
      .map(x => (x._1._1, x._1._2, x._2)), maxEntriesPerNode)
    indexed.setName(table_name.map(n => s"$n $index_name").getOrElse(child.toString))
    _indexedRDD = indexed
  }

  /** Copy with fresh output attribute instances that reuses the built index and global tree. */
  override def newInstance(): IndexedRelation = {
    // Forward global_rtree as withOutput does; otherwise the copy would keep
    // the partition index but end up with a null global tree.
    RTreeIndexedRelation(output.map(_.newInstance()), child, table_name,
      column_keys, index_name)(_indexedRDD, global_rtree).asInstanceOf[this.type]
  }

  /** Copy with the given output attributes that reuses the built index and global tree. */
  override def withOutput(new_output: Seq[Attribute]): IndexedRelation = {
    RTreeIndexedRelation(new_output, child, table_name,
      column_keys, index_name)(_indexedRDD, global_rtree)
  }
}