org.apache.spark.InterruptibleIterator Scala Examples

The following examples show how to use org.apache.spark.InterruptibleIterator. Each example is taken from an open source project; the source file and originating project are noted above the code.
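
Before the project examples, here is a minimal, self-contained sketch of the pattern they all share (the class names RangePartition and InterruptibleRangeRDD are hypothetical, not taken from any project below): a custom RDD's compute() builds an iterator over its partition's data and wraps it in InterruptibleIterator, so that a task-kill request delivered through the TaskContext stops iteration as the data is consumed.

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical partition type carrying a simple integer range.
class RangePartition(override val index: Int, val start: Int, val end: Int) extends Partition

// Hypothetical RDD illustrating the InterruptibleIterator usage pattern.
class InterruptibleRangeRDD(sc: SparkContext, max: Int, numSlices: Int)
  extends RDD[Int](sc, Nil) {

  override protected def getPartitions: Array[Partition] = {
    val step = math.max(1, max / numSlices)
    (0 until numSlices).map { i =>
      val start = i * step
      val end = if (i == numSlices - 1) max else math.min(max, (i + 1) * step)
      new RangePartition(i, start, end)
    }.toArray
  }

  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    val p = split.asInstanceOf[RangePartition]
    // InterruptibleIterator checks the TaskContext's kill flag as the iterator is
    // consumed and stops (failing the task) once the task has been interrupted.
    new InterruptibleIterator(context, Iterator.range(p.start, p.end))
  }
}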
Example 1
Source File: DataSourceRDD.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.v2

import scala.reflect.ClassTag

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.v2.reader.InputPartition

class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T])
  extends Partition with Serializable

class DataSourceRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val inputPartitions: Seq[InputPartition[T]])
  extends RDD[T](sc, Nil) {

  override protected def getPartitions: Array[Partition] = {
    inputPartitions.zipWithIndex.map {
      case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition)
    }.toArray
  }

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition
        .createPartitionReader()
    context.addTaskCompletionListener[Unit](_ => reader.close())
    val iter = new Iterator[T] {
      private[this] var valuePrepared = false

      override def hasNext: Boolean = {
        if (!valuePrepared) {
          valuePrepared = reader.next()
        }
        valuePrepared
      }

      override def next(): T = {
        if (!hasNext) {
          throw new java.util.NoSuchElementException("End of stream")
        }
        valuePrepared = false
        reader.get()
      }
    }
    // Wrap in an InterruptibleIterator so a task-kill request can stop consumption.
    new InterruptibleIterator(context, iter)
  }

  override def getPreferredLocations(split: Partition): Seq[String] = {
    split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations()
  }
} 
Example 2
Source File: SplashShuffleReader.scala    From splash   with Apache License 2.0
package org.apache.spark.shuffle

import org.apache.spark.internal.Logging
import org.apache.spark.storage.BlockId
import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}


private[spark] class SplashShuffleReader[K, V](
    resolver: SplashShuffleBlockResolver,
    handle: BaseShuffleHandle[K, _, V],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
    extends ShuffleReader[K, V] with Logging {

  private val dep = handle.dependency

  private type Pair = (Any, Any)
  private type KCPair = (K, V)
  private type KCIterator = Iterator[KCPair]

  override def read(): KCIterator = {
    val shuffleBlocks = mapOutputTracker.getMapSizesByExecutorId(
      handle.shuffleId, startPartition, endPartition)
        .flatMap(_._2)
    readShuffleBlocks(shuffleBlocks)
  }

  def readShuffleBlocks(shuffleBlocks: Seq[(BlockId, Long)]): KCIterator =
    readShuffleBlocks(shuffleBlocks.iterator)

  def readShuffleBlocks(shuffleBlocks: Iterator[(BlockId, Long)]): KCIterator = {
    val taskMetrics = context.taskMetrics()
    val serializer = SplashSerializer(dep)

    val nonEmptyBlocks = shuffleBlocks.filter(_._2 > 0).map(_._1)
    val fetcherIterator = SplashShuffleFetcherIterator(resolver, nonEmptyBlocks)

    def getAggregatedIterator(iterator: Iterator[Pair]): KCIterator = {
      dep.aggregator match {
        case Some(agg) =>
          val aggregator = new SplashAggregator(agg)
          if (dep.mapSideCombine) {
            // We are reading values that are already combined
            val combinedKeyValuesIterator = iterator.asInstanceOf[Iterator[(K, V)]]
            aggregator.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // We don't know the value type, but also don't care -- the dependency *should*
            // have made sure it's compatible with this aggregator, which will convert the
            // value type to the combined type C.
            val keyValuesIterator = iterator.asInstanceOf[Iterator[(K, Nothing)]]
            aggregator.combineValuesByKey(keyValuesIterator, context)
          }
        case None =>
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
          iterator.asInstanceOf[KCIterator]
      }
    }

    def getSortedIterator(iterator: KCIterator): KCIterator = {
      // Sort the output if there is a sort ordering defined.
      dep.keyOrdering match {
        case Some(keyOrd: Ordering[K]) =>
          // Create an ExternalSorter to sort the data.
          val sorter = new SplashSorter[K, V, V](
            context, ordering = Some(keyOrd), serializer = serializer)
          sorter.insertAll(iterator)
          sorter.updateTaskMetrics()
          sorter.completionIterator
        case None =>
          iterator
      }
    }

    val metricIter = fetcherIterator.flatMap(
      _.asMetricIterator(serializer, taskMetrics))

    // An interruptible iterator must be used here in order to support task cancellation
    getSortedIterator(
      getAggregatedIterator(
        new InterruptibleIterator[Pair](context, metricIter)))
  }
} 
Example 3
Source File: HashShuffleReader.scala    From SparkCore   with Apache License 2.0
package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.util.collection.ExternalSorter

private[spark] class HashShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext)
  extends ShuffleReader[K, C]
{
  require(endPartition == startPartition + 1,
    "Hash shuffle currently only supports fetching one partition")

  private val dep = handle.dependency

  /** Read the combined key-values for this reduce task */
  override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)
    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
      } else {
        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
        sorter.iterator
      case None =>
        aggregatedIter
    }
  }
} 
Example 4
Source File: MemsqlRDD.scala    From memsql-spark-connector   with Apache License 2.0
package com.memsql.spark

import java.sql.{Connection, PreparedStatement, ResultSet}

import com.memsql.spark.SQLGen.VariableList
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}

case class MemsqlRDD(query: String,
                     variables: VariableList,
                     options: MemsqlOptions,
                     schema: StructType,
                     expectedOutput: Seq[Attribute],
                     @transient val sc: SparkContext)
    extends RDD[Row](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    MemsqlQueryHelpers.GetPartitions(options, query, variables)

  override def compute(rawPartition: Partition, context: TaskContext): Iterator[Row] = {
    var closed                     = false
    var rs: ResultSet              = null
    var stmt: PreparedStatement    = null
    var conn: Connection           = null
    var partition: MemsqlPartition = rawPartition.asInstanceOf[MemsqlPartition]

    def tryClose(name: String, what: AutoCloseable): Unit = {
      try {
        if (what != null) { what.close() }
      } catch {
        case e: Exception => logWarning(s"Exception closing $name", e)
      }
    }

    def close(): Unit = {
      if (closed) { return }
      tryClose("resultset", rs)
      tryClose("statement", stmt)
      tryClose("connection", conn)
      closed = true
    }

    context.addTaskCompletionListener { context =>
      close()
    }

    conn = JdbcUtils.createConnectionFactory(partition.connectionInfo)()
    stmt = conn.prepareStatement(partition.query)
    JdbcHelpers.fillStatement(stmt, partition.variables)
    rs = stmt.executeQuery()

    var rowsIter = JdbcUtils.resultSetToRows(rs, schema)

    if (expectedOutput.nonEmpty) {
      val schemaDatatypes   = schema.map(_.dataType)
      val expectedDatatypes = expectedOutput.map(_.dataType)

      if (schemaDatatypes != expectedDatatypes) {
        val columnEncoders = schemaDatatypes.zip(expectedDatatypes).zipWithIndex.map {
          case ((_: StringType, _: NullType), _)     => ((_: Row) => null)
          case ((_: ShortType, _: BooleanType), i)   => ((r: Row) => r.getShort(i) != 0)
          case ((_: IntegerType, _: BooleanType), i) => ((r: Row) => r.getInt(i) != 0)
          case ((_: LongType, _: BooleanType), i)    => ((r: Row) => r.getLong(i) != 0)

          case ((l, r), i) => {
            options.assert(l == r, s"MemsqlRDD: unable to encode ${l} into ${r}")
            ((r: Row) => r.get(i))
          }
        }

        rowsIter = rowsIter
          .map(row => Row.fromSeq(columnEncoders.map(_(row))))
      }
    }

    CompletionIterator[Row, Iterator[Row]](new InterruptibleIterator[Row](context, rowsIter), close)
  }

} 
Example 5
Source File: HashShuffleReader.scala    From iolap   with Apache License 2.0
package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.util.collection.ExternalSorter

private[spark] class HashShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext)
  extends ShuffleReader[K, C]
{
  require(endPartition == startPartition + 1,
    "Hash shuffle currently only supports fetching one partition")

  private val dep = handle.dependency

  /** Read the combined key-values for this reduce task */
  override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)
    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
      } else {
        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
        sorter.iterator
      case None =>
        aggregatedIter
    }
  }
} 
Example 6
Source File: DataSourceRDD.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory

class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
  extends Partition with Serializable

class DataSourceRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
  extends RDD[T](sc, Nil) {

  override protected def getPartitions: Array[Partition] = {
    readerFactories.asScala.zipWithIndex.map {
      case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
    }.toArray
  }

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
    context.addTaskCompletionListener(_ => reader.close())
    val iter = new Iterator[T] {
      private[this] var valuePrepared = false

      override def hasNext: Boolean = {
        if (!valuePrepared) {
          valuePrepared = reader.next()
        }
        valuePrepared
      }

      override def next(): T = {
        if (!hasNext) {
          throw new java.util.NoSuchElementException("End of stream")
        }
        valuePrepared = false
        reader.get()
      }
    }
    // Wrap in an InterruptibleIterator so a task-kill request can stop consumption.
    new InterruptibleIterator(context, iter)
  }

  override def getPreferredLocations(split: Partition): Seq[String] = {
    split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations()
  }
} 
Example 7
Source File: NewHBaseRDD.scala    From hbase-connectors   with Apache License 2.0
package org.apache.hadoop.hbase.spark

import org.apache.hadoop.conf.Configuration
import org.apache.yetus.audience.InterfaceAudience
import org.apache.hadoop.mapreduce.InputFormat
import org.apache.spark.rdd.NewHadoopRDD
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}

@InterfaceAudience.Public
class NewHBaseRDD[K, V](@transient val sc: SparkContext,
                        @transient val inputFormatClass: Class[_ <: InputFormat[K, V]],
                        @transient val keyClass: Class[K],
                        @transient val valueClass: Class[V],
                        @transient private val __conf: Configuration,
                        val hBaseContext: HBaseContext)
  extends NewHadoopRDD(sc, inputFormatClass, keyClass, valueClass, __conf) {

  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
    // Apply the broadcast HBase credentials on the executor, then delegate to
    // NewHadoopRDD.compute, which already returns an InterruptibleIterator.
    hBaseContext.applyCreds()
    super.compute(theSplit, context)
  }
}