org.apache.spark.sql.streaming.StreamingQuery Scala Examples

The following examples show how to use org.apache.spark.sql.streaming.StreamingQuery. Each example is extracted from an open source project; the source file, project, and license are listed above the code.
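Before the project examples, here is a minimal, self-contained sketch of the basic StreamingQuery lifecycle (start, inspect, stop), using only Spark's built-in rate source and console sink; the object and query names are illustrative and not taken from any of the projects below.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}

object StreamingQueryQuickStart {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("streaming-query-quick-start").getOrCreate()

    // The rate source emits (timestamp, value) rows; the console sink prints each micro-batch.
    val query: StreamingQuery = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .load()
      .writeStream
      .format("console")
      .trigger(Trigger.ProcessingTime("2 seconds"))
      .queryName("quick-start")
      .start()

    // A StreamingQuery exposes identity, status and progress information while it runs.
    println(query.id)
    println(query.status)

    query.awaitTermination(10000) // wait up to 10 seconds for data to flow
    query.stop()
    spark.stop()
  }
}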
Example 1
Source File: MapGroupsWithStateApp.scala    From spark-structured-streaming-examples   with Apache License 2.0
package com.phylosoft.spark.learning.sql.streaming.operations.stateful

import com.phylosoft.spark.learning.sql.streaming.domain.Model.{Event, SessionInfo, SessionUpdate}
import com.phylosoft.spark.learning.sql.streaming.monitoring.Monitoring
import com.phylosoft.spark.learning.sql.streaming.sink.StreamingSink
import com.phylosoft.spark.learning.sql.streaming.sink.console.ConsoleSink
import com.phylosoft.spark.learning.sql.streaming.source.rate.UserActionsRateSource
import com.phylosoft.spark.learning.{Logger, SparkSessionConfiguration}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, Trigger}

object MapGroupsWithStateApp
  extends App
    with SparkSessionConfiguration
    with GroupsWithStateFunction
    with Monitoring
    with Logger {

  val settings = Map("spark.app.name" -> "MapGroupsWithStateApp")

  spark.streams.addListener(simpleListener)

  val source = new UserActionsRateSource(spark)

  val userActions = source.loadUserActions()
  userActions.printSchema()

  import spark.implicits._

  val events = userActions
    .withColumnRenamed("userId", "sessionId")
    .withColumnRenamed("actionTime", "timestamp")
    .as[Event]
  events.printSchema()

  // Sessionize the events. Track the number of events and the start and end timestamps of each
  // session, and report session updates.
  val timeTimeoutMode = "ProcessingTime"

  val sessionUpdates = timeTimeoutMode match {
    case "ProcessingTime" =>
      events
        .groupByKey(event => event.sessionId)
        .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) {
          sessionUpdate
        }
    case _ =>
      events
        .withWatermark("timestamp", "2 seconds")
        .groupByKey(event => event.sessionId)
        .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.EventTimeTimeout) {
          sessionUpdate
        }
  }

  val sessions = sessionUpdates
    .select($"*")
    .where("expired == true")
  sessions.printSchema()

  // Start running the query that prints the session updates to the console
  val query = startStreamingSink(sessions, initStreamingSink)

  query.awaitTermination()

  private def startStreamingSink[T <: StreamingSink](data: DataFrame, sink: T): StreamingQuery = {
    sink.writeStream(data)
  }

  private def initStreamingSink: StreamingSink = {
    import scala.concurrent.duration._
    new ConsoleSink(trigger = Trigger.ProcessingTime(2.seconds), outputMode = OutputMode.Append())
  }

} 
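The sessionUpdate value used above comes from the GroupsWithStateFunction trait that the app mixes in, which is not shown on this page. The sketch below is an assumption of what such a function could look like, with hypothetical Event, SessionInfo and SessionUpdate definitions inferred from how the example uses them; the real project's code may differ.

import java.sql.Timestamp
import org.apache.spark.sql.streaming.GroupState

// Hypothetical domain types matching what the example implies.
case class Event(sessionId: String, timestamp: Timestamp)
case class SessionInfo(numEvents: Int, startTimestampMs: Long, endTimestampMs: Long)
case class SessionUpdate(id: String, durationMs: Long, numEvents: Int, expired: Boolean)

trait GroupsWithStateFunction {
  // Update function for mapGroupsWithState: fold a session's new events into its state,
  // or emit an expired update once the state times out (ProcessingTimeTimeout assumed here).
  val sessionUpdate: (String, Iterator[Event], GroupState[SessionInfo]) => SessionUpdate =
    (sessionId, events, state) =>
      if (state.hasTimedOut) {
        val info = state.get
        state.remove()
        SessionUpdate(sessionId, info.endTimestampMs - info.startTimestampMs, info.numEvents, expired = true)
      } else {
        val timestamps = events.map(_.timestamp.getTime).toSeq
        val updated =
          if (state.exists) {
            val old = state.get
            SessionInfo(old.numEvents + timestamps.size, old.startTimestampMs,
              math.max(timestamps.max, old.endTimestampMs))
          } else {
            SessionInfo(timestamps.size, timestamps.min, timestamps.max)
          }
        state.update(updated)
        state.setTimeoutDuration("10 seconds") // only valid with ProcessingTimeTimeout
        SessionUpdate(sessionId, updated.endTimestampMs - updated.startTimestampMs,
          updated.numEvents, expired = false)
      }
}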
Example 2
Source File: MemorySink.scala    From spark-structured-streaming-examples   with Apache License 2.0
package com.phylosoft.spark.learning.sql.streaming.sink.memory

import com.phylosoft.spark.learning.sql.streaming.sink.StreamingSink
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, Trigger}


class MemorySink(trigger: Trigger = Trigger.Once(),
                 outputMode: OutputMode = OutputMode.Update())
  extends StreamingSink {

  override def writeStream(data: DataFrame): StreamingQuery = {
    data.writeStream
      .format("memory")
      .trigger(trigger)
      .outputMode(outputMode)
      .option("checkpointLocation", checkpointLocation + "/memory")
      .start()
  }

} 
Example 3
Source File: DeltaSink.scala    From spark-structured-streaming-examples   with Apache License 2.0
package com.phylosoft.spark.learning.sql.streaming.sink.delta

import com.phylosoft.spark.learning.sql.streaming.sink.StreamingSink
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, Trigger}


class DeltaSink(trigger: Trigger = Trigger.Once(),
                outputMode: OutputMode = OutputMode.Append())
  extends StreamingSink {

  override def writeStream(data: DataFrame): StreamingQuery = {
    data.writeStream
      .format("delta")
      .trigger(trigger)
      .outputMode(outputMode)
      .option("checkpointLocation", checkpointLocation + "/tmp/delta/events")
      .start("/tmp/delta/events")
  }

} 
Example 4
Source File: ConsoleSink.scala    From spark-structured-streaming-examples   with Apache License 2.0
package com.phylosoft.spark.learning.sql.streaming.sink.console

import com.phylosoft.spark.learning.sql.streaming.sink.StreamingSink
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, Trigger}

class ConsoleSink(trigger: Trigger = Trigger.Once(),
                  outputMode: OutputMode = OutputMode.Update())
  extends StreamingSink {

  override def writeStream(data: DataFrame): StreamingQuery = {
    data.writeStream
      .format("console")
      .trigger(trigger)
      .outputMode(outputMode)
      .option("checkpointLocation", checkpointLocation + "/console")
      .start()
  }

} 
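The MemorySink, DeltaSink and ConsoleSink classes above all extend the project's StreamingSink trait, which is not shown on this page. A plausible minimal shape for it, inferred only from the members the sinks use (writeStream and checkpointLocation), is sketched below; treat it as an assumption rather than the project's actual code.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.streaming.StreamingQuery

// Assumed shape: each sink starts a StreamingQuery for a given DataFrame and shares a base
// checkpoint directory to which the concrete sinks append their own suffix.
trait StreamingSink {
  val checkpointLocation: String = "checkpoint" // assumption: the real project may configure this
  def writeStream(data: DataFrame): StreamingQuery
}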
Example 5
Source File: StreamingQueryWrapper.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import java.util.UUID

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus}


class StreamingQueryWrapper(@transient private val _streamingQuery: StreamExecution)
  extends StreamingQuery with Serializable {

  def streamingQuery: StreamExecution = {
    // Only usable on the driver: the @transient reference is null once the wrapper
    // has been serialized into an executor.
    if (_streamingQuery == null) {
      throw new IllegalStateException("StreamingQuery cannot be used in executors")
    }
    _streamingQuery
  }

  def explainInternal(extended: Boolean): String = {
    streamingQuery.explainInternal(extended)
  }
    streamingQuery.explainInternal(extended)
  }

  override def sparkSession: SparkSession = {
    streamingQuery.sparkSession
  }

  override def recentProgress: Array[StreamingQueryProgress] = {
    streamingQuery.recentProgress
  }

  override def status: StreamingQueryStatus = {
    streamingQuery.status
  }

  override def exception: Option[StreamingQueryException] = {
    streamingQuery.exception
  }
} 
Example 6
Source File: StreamMetadata.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import java.io.{InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets
import java.util.ConcurrentModificationException

import scala.util.control.NonFatal

import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileAlreadyExistsException, FSDataInputStream, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
import org.apache.spark.sql.streaming.StreamingQuery


case class StreamMetadata(id: String) {
  def json: String = Serialization.write(this)(StreamMetadata.format)
}

object StreamMetadata extends Logging {
  implicit val format = Serialization.formats(NoTypeHints)

  // Write the metadata of a StreamingQuery to the given checkpoint file, failing if another
  // query is already writing to the same location.
  def write(
      metadata: StreamMetadata,
      metadataFile: Path,
      hadoopConf: Configuration): Unit = {
    var output: CancellableFSDataOutputStream = null
    try {
      val fileManager = CheckpointFileManager.create(metadataFile.getParent, hadoopConf)
      output = fileManager.createAtomic(metadataFile, overwriteIfPossible = false)
      val writer = new OutputStreamWriter(output)
      Serialization.write(metadata, writer)
      writer.close()
    } catch {
      case e: FileAlreadyExistsException =>
        if (output != null) {
          output.cancel()
        }
        throw new ConcurrentModificationException(
          s"Multiple streaming queries are concurrently using $metadataFile", e)
      case e: Throwable =>
        if (output != null) {
          output.cancel()
        }
        logError(s"Error writing stream metadata $metadata to $metadataFile", e)
        throw e
    }
  }
} 
Example 7
Source File: StreamMetadata.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import java.io.{InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets

import scala.util.control.NonFatal

import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.streaming.StreamingQuery


case class StreamMetadata(id: String) {
  def json: String = Serialization.write(this)(StreamMetadata.format)
}

object StreamMetadata extends Logging {
  implicit val format = Serialization.formats(NoTypeHints)

  // Write the metadata of a StreamingQuery to the given checkpoint file.
  def write(
      metadata: StreamMetadata,
      metadataFile: Path,
      hadoopConf: Configuration): Unit = {
    var output: FSDataOutputStream = null
    try {
      val fs = FileSystem.get(hadoopConf)
      output = fs.create(metadataFile)
      val writer = new OutputStreamWriter(output)
      Serialization.write(metadata, writer)
      writer.close()
    } catch {
      case NonFatal(e) =>
        logError(s"Error writing stream metadata $metadata to $metadataFile", e)
        throw e
    } finally {
      IOUtils.closeQuietly(output)
    }
  }
} 
Example 8
Source File: StreamMetadata.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import java.io.{InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets

import scala.util.control.NonFatal

import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.streaming.StreamingQuery


case class StreamMetadata(id: String) {
  def json: String = Serialization.write(this)(StreamMetadata.format)
}

object StreamMetadata extends Logging {
  implicit val format = Serialization.formats(NoTypeHints)

  // Write the metadata of a StreamingQuery to the given checkpoint file.
  def write(
      metadata: StreamMetadata,
      metadataFile: Path,
      hadoopConf: Configuration): Unit = {
    var output: FSDataOutputStream = null
    try {
      val fs = FileSystem.get(hadoopConf)
      output = fs.create(metadataFile)
      val writer = new OutputStreamWriter(output)
      Serialization.write(metadata, writer)
      writer.close()
    } catch {
      case NonFatal(e) =>
        logError(s"Error writing stream metadata $metadata to $metadataFile", e)
        throw e
    } finally {
      IOUtils.closeQuietly(output)
    }
  }
} 
Example 9
Source File: highcharts.scala    From spark-highcharts   with Apache License 2.0
package com.knockdata.spark.highcharts

import com.knockdata.spark.highcharts.model._
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.zeppelin.spark.ZeppelinContext

import scala.collection.mutable

object highcharts {
  private def nextParagraphId(z: ZeppelinContext): String = {
    val currentParagraphId = z.getInterpreterContext.getParagraphId
    val paragraphIds = z.listParagraphs.toArray
    val currentIndex = paragraphIds.indexOf(currentParagraphId)
    paragraphIds(currentIndex + 1).toString
  }

  def streamingChart(seriesHolder: SeriesHolder,
                     zHolder: ZeppelinContextHolder,
                     chartParagraphId: String, outputMode: String = "append"): StreamingQuery = {
    val chartId = seriesHolder.chartId
    Registry.put(s"$chartId-seriesHolder", seriesHolder)
    Registry.put(s"$chartId-z", zHolder)

    val writeStream = seriesHolder.dataFrame.writeStream
      .format(classOf[CustomSinkProvider].getCanonicalName)
      .option("chartId", chartId)
      .option("chartParagraphId", chartParagraphId)

    outputMode match {
      case "complete" =>
        Registry.put (s"$chartId-outputMode", new CompleteOutputMode())
        writeStream.outputMode("complete").start()
      case "append" =>
        Registry.put (s"$chartId-outputMode", new AppendOutputMode(200))
        writeStream.outputMode("append").start()
      case _ =>
        throw new Exception("outputMode must be either append or complete")
    }

  }

  def apply(seriesHolder: SeriesHolder,
            z: ZeppelinContext,
            outputMode: String = null): StreamingQuery = {
    val chartParagraph = nextParagraphId(z)
    streamingChart(seriesHolder, new ZeppelinContextHolder(z), chartParagraph, outputMode)
  }

  def apply(seriesHolder: SeriesHolder,
            z: ZeppelinContext,
            outputMode: String,
            chartParagraph: String): StreamingQuery = {
    streamingChart(seriesHolder, new ZeppelinContextHolder(z), chartParagraph, outputMode)
  }

  def apply(seriesHolders: SeriesHolder*): Highcharts = {
    val normalSeriesBuffer = mutable.Buffer[Series]()
    val drilldownSeriesBuffer = mutable.Buffer[Series]()
    for (holder <- seriesHolders) {
      val (normalSeriesList, drilldownSeriesList) = holder.result

      normalSeriesBuffer ++= normalSeriesList
      drilldownSeriesBuffer ++= drilldownSeriesList
    }

    new Highcharts(normalSeriesBuffer.toList).drilldown(drilldownSeriesBuffer.toList)
  }
} 
Example 10
Source File: StreamingQueryWrapper.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import java.util.UUID

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus}


class StreamingQueryWrapper(@transient private val _streamingQuery: StreamExecution)
  extends StreamingQuery with Serializable {

  def streamingQuery: StreamExecution = {
    // Only usable on the driver: the @transient reference is null once the wrapper
    // has been serialized into an executor.
    if (_streamingQuery == null) {
      throw new IllegalStateException("StreamingQuery cannot be used in executors")
    }
    _streamingQuery
  }

  def explainInternal(extended: Boolean): String = {
    streamingQuery.explainInternal(extended)
  }
    streamingQuery.explainInternal(extended)
  }

  override def sparkSession: SparkSession = {
    streamingQuery.sparkSession
  }

  override def recentProgress: Array[StreamingQueryProgress] = {
    streamingQuery.recentProgress
  }

  override def status: StreamingQueryStatus = {
    streamingQuery.status
  }

  override def exception: Option[StreamingQueryException] = {
    streamingQuery.exception
  }
} 
Example 11
Source File: StreamMetadata.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import java.io.{InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets

import scala.util.control.NonFatal

import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.streaming.StreamingQuery


case class StreamMetadata(id: String) {
  def json: String = Serialization.write(this)(StreamMetadata.format)
}

object StreamMetadata extends Logging {
  implicit val format = Serialization.formats(NoTypeHints)

  // Write the metadata of a StreamingQuery to the given checkpoint file.
  def write(
      metadata: StreamMetadata,
      metadataFile: Path,
      hadoopConf: Configuration): Unit = {
    var output: FSDataOutputStream = null
    try {
      val fs = metadataFile.getFileSystem(hadoopConf)
      output = fs.create(metadataFile)
      val writer = new OutputStreamWriter(output)
      Serialization.write(metadata, writer)
      writer.close()
    } catch {
      case NonFatal(e) =>
        logError(s"Error writing stream metadata $metadata to $metadataFile", e)
        throw e
    } finally {
      IOUtils.closeQuietly(output)
    }
  }
} 
Example 12
Source File: ElasticSink.scala    From Spark-Structured-Streaming-Examples   with Apache License 2.0
package elastic

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import radio.{SimpleSongAggregation, Song}
import org.elasticsearch.spark.sql.streaming._
import org.elasticsearch.spark.sql._
import org.elasticsearch.spark.sql.streaming.EsSparkSqlStreamingSink

object ElasticSink {
  def writeStream(ds: Dataset[Song]): StreamingQuery = {
    // Note: Append output mode is not supported when there are streaming aggregations on
    // streaming DataFrames/Datasets without a watermark.
    ds
      .writeStream
      .outputMode(OutputMode.Append) // the only output mode supported by the Elasticsearch sink
      .format("org.elasticsearch.spark.sql")
      .queryName("ElasticSink")
      .start("test/broadcast") // Elasticsearch index to write to
  }

} 
Example 13
Source File: KafkaSink.scala    From Spark-Structured-Streaming-Examples   with Apache License 2.0
package kafka

import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{struct, to_json, _}
import _root_.log.LazyLogger
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.types.{StringType, _}
import radio.{SimpleSongAggregation, SimpleSongAggregationKafka}
import spark.SparkHelper

object KafkaSink extends LazyLogger {
  private val spark = SparkHelper.getSparkSession()

  import spark.implicits._

  def writeStream(staticInputDS: Dataset[SimpleSongAggregation]): StreamingQuery = {
    log.warn("Writing to Kafka")
    staticInputDS
      .select(to_json(struct($"*")).cast(StringType).alias("value"))
      .writeStream
      .outputMode("update")
      .format("kafka")
      .option("kafka.bootstrap.servers", KafkaService.bootstrapServers)
      .queryName("Kafka - Count number of broadcasts for a title/artist by radio")
      .option("topic", "test")
      .start()
  }

  def debugStream(staticKafkaInputDS: Dataset[SimpleSongAggregationKafka]): StreamingQuery = {
    staticKafkaInputDS
      .writeStream
      .queryName("Debug Stream Kafka")
      .format("console")
      .start()
  }
} 
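To inspect what KafkaSink writes, the matching read side is Spark's standard kafka source. A minimal sketch, assuming a broker reachable at localhost:9092 and the same test topic used above; the object name is illustrative.

import org.apache.spark.sql.SparkSession

object KafkaTopicTail {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("kafka-topic-tail").getOrCreate()

    val query = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092") // assumption: local broker
      .option("subscribe", "test")                         // topic written by KafkaSink above
      .option("startingOffsets", "earliest")
      .load()
      .selectExpr("CAST(value AS STRING) AS json") // KafkaSink wrote JSON strings as the value
      .writeStream
      .format("console")
      .outputMode("append")
      .start()

    query.awaitTermination()
  }
}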
Example 14
Source File: SparkStreamletContextImpl.scala    From cloudflow   with Apache License 2.0
package cloudflow.spark.kafka

import java.io.File

import com.typesafe.config.Config
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.{ ExpressionEncoder, RowEncoder }
import org.apache.spark.sql.streaming.{ OutputMode, StreamingQuery }
import cloudflow.spark.SparkStreamletContext
import cloudflow.spark.avro.{ SparkAvroDecoder, SparkAvroEncoder }
import cloudflow.spark.sql.SQLImplicits._
import cloudflow.streamlets._

import scala.reflect.runtime.universe._

class SparkStreamletContextImpl(
    private[cloudflow] override val streamletDefinition: StreamletDefinition,
    session: SparkSession,
    override val config: Config
) extends SparkStreamletContext(streamletDefinition, session) {

  val storageDir           = config.getString("storage.mountPath")
  val maxOffsetsPerTrigger = config.getLong("cloudflow.spark.read.options.max-offsets-per-trigger")
  def readStream[In](inPort: CodecInlet[In])(implicit encoder: Encoder[In], typeTag: TypeTag[In]): Dataset[In] = {

    implicit val inRowEncoder: ExpressionEncoder[Row] = RowEncoder(encoder.schema)
    val schema                                        = inPort.schemaAsString
    val topic                                         = findTopicForPort(inPort)
    val srcTopic                                      = topic.name
    val brokers                                       = topic.bootstrapServers.getOrElse(internalKafkaBootstrapServers)
    val src: DataFrame = session.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", brokers)
      .options(kafkaConsumerMap(topic))
      .option("maxOffsetsPerTrigger", maxOffsetsPerTrigger)
      .option("subscribe", srcTopic)
      // Allow restart of stateful streamlets that may have been offline for longer than the kafka retention period.
      // This setting may result in data loss in some cases but allows for continuity of the runtime
      .option("failOnDataLoss", false)
      .option("startingOffsets", "earliest")
      .load()

    val rawDataset = src.select($"value").as[Array[Byte]]

    val dataframe: Dataset[Row] = rawDataset.mapPartitions { iter ⇒
      val avroDecoder = new SparkAvroDecoder[In](schema)
      iter.map(avroDecoder.decode)
    }(inRowEncoder)

    dataframe.as[In]
  }

  def kafkaConsumerMap(topic: Topic) = topic.kafkaConsumerProperties.map {
    case (key, value) => s"kafka.$key" -> value
  }
  def kafkaProducerMap(topic: Topic) = topic.kafkaProducerProperties.map {
    case (key, value) => s"kafka.$key" -> value
  }

  def writeStream[Out](stream: Dataset[Out], outPort: CodecOutlet[Out], outputMode: OutputMode)(implicit encoder: Encoder[Out],
                                                                                                typeTag: TypeTag[Out]): StreamingQuery = {

    val avroEncoder   = new SparkAvroEncoder[Out](outPort.schemaAsString)
    val encodedStream = avroEncoder.encodeWithKey(stream, outPort.partitioner)

    val topic     = findTopicForPort(outPort)
    val destTopic = topic.name
    val brokers   = topic.bootstrapServers.getOrElse(internalKafkaBootstrapServers)

    // metadata checkpoint directory on mount
    val checkpointLocation = checkpointDir(outPort.name)
    val queryName          = s"$streamletRef.$outPort"

    encodedStream.writeStream
      .outputMode(outputMode)
      .format("kafka")
      .queryName(queryName)
      .option("kafka.bootstrap.servers", brokers)
      .options(kafkaProducerMap(topic))
      .option("topic", destTopic)
      .option("checkpointLocation", checkpointLocation)
      .start()
  }

  def checkpointDir(dirName: String): String = {
    val baseCheckpointDir = new File(storageDir, streamletRef)
    val dir               = new File(baseCheckpointDir, dirName)
    if (!dir.exists()) {
      val created = dir.mkdirs()
      require(created, s"Could not create checkpoint directory: $dir")
    }
    dir.getAbsolutePath
  }
} 
Example 15
Source File: TestSparkStreamletContext.scala    From cloudflow   with Apache License 2.0
package cloudflow.spark
package testkit

import java.nio.file.attribute.FileAttribute

import com.typesafe.config._

import scala.reflect.runtime.universe._
import scala.concurrent.duration._
import org.apache.spark.sql.{ Dataset, Encoder, SparkSession }
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.{ OutputMode, StreamingQuery, Trigger }
import cloudflow.streamlets._
import org.apache.spark.sql.catalyst.InternalRow


class TestSparkStreamletContext(override val streamletRef: String,
                                session: SparkSession,
                                inletTaps: Seq[SparkInletTap[_]],
                                outletTaps: Seq[SparkOutletTap[_]],
                                override val config: Config = ConfigFactory.empty)
    extends SparkStreamletContext(StreamletDefinition("appId", "appVersion", streamletRef, "streamletClass", List(), List(), config),
                                  session) {
  val ProcessingTimeInterval = 1500.milliseconds
  override def readStream[In](inPort: CodecInlet[In])(implicit encoder: Encoder[In], typeTag: TypeTag[In]): Dataset[In] =
    inletTaps
      .find(_.portName == inPort.name)
      .map(_.instream.asInstanceOf[MemoryStream[In]].toDF.as[In])
      .getOrElse(throw TestContextException(inPort.name, s"Bad test context, could not find source for inlet ${inPort.name}"))

  override def writeStream[Out](stream: Dataset[Out],
                                outPort: CodecOutlet[Out],
                                outputMode: OutputMode)(implicit encoder: Encoder[Out], typeTag: TypeTag[Out]): StreamingQuery = {
    // RateSource can only work with a microBatch query because it contains no data at time zero.
    // Trigger.Once requires data at start to work.
    val trigger = if (isRateSource(stream)) {
      Trigger.ProcessingTime(ProcessingTimeInterval)
    } else {
      Trigger.Once()
    }
    val streamingQuery = outletTaps
      .find(_.portName == outPort.name)
      .map { outletTap ⇒
        stream.writeStream
          .outputMode(outputMode)
          .format("memory")
          .trigger(trigger)
          .queryName(outletTap.queryName)
          .start()
      }
      .getOrElse(throw TestContextException(outPort.name, s"Bad test context, could not find destination for outlet ${outPort.name}"))
    streamingQuery
  }

  override def checkpointDir(dirName: String): String = {
    val fileAttributes: Array[FileAttribute[_]] = Array()
    val tmpDir                                  = java.nio.file.Files.createTempDirectory("spark-test", fileAttributes: _*)
    tmpDir.toFile.getAbsolutePath
  }

  private def isRateSource(stream: Dataset[_]): Boolean = {
    import org.apache.spark.sql.execution.command.ExplainCommand
    val explain = ExplainCommand(stream.queryExecution.logical, true)
    val res     = session.sessionState.executePlan(explain).executedPlan.executeCollect()
    res.exists((row: InternalRow) => row.getString(0).contains("org.apache.spark.sql.execution.streaming.sources.RateStreamProvider"))
  }

}

case class TestContextException(portName: String, msg: String) extends RuntimeException(msg) 
Example 16
Source File: StreamingTestHelper.scala    From spark-acid   with Apache License 2.0
package com.qubole.spark.hiveacid.streaming

import java.io.{File, IOException}
import java.util.UUID

import com.qubole.spark.hiveacid.TestHelper

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import org.scalatest.concurrent.TimeLimits
import org.scalatest.time.SpanSugar

class StreamingTestHelper extends TestHelper with TimeLimits  {

  import StreamingTestHelper._


  def runStreaming(tableName: String,
                   outputMode: OutputMode,
                   cols: Seq[String],
                   inputRange: Range,
                   options: List[(String, String)] = List.empty): Unit = {

    val inputData = MemoryStream[Int]
    val ds = inputData.toDS()

    val checkpointDir = createCheckpointDir(namePrefix = "stream.checkpoint").getCanonicalPath

    var query: StreamingQuery = null

    try {
      // Starting streaming query
      val writerDf =
        ds.map(i => (i*100, i*10, i))
          .toDF(cols:_*)
          .writeStream
          .format("HiveAcid")
          .option("table", tableName)
          .outputMode(outputMode)
          .option("checkpointLocation", checkpointDir)
          //.start()

      query = options.map { option =>
        writerDf.option(option._1, option._2)
      }.lastOption.getOrElse(writerDf).start()

      // Adding data for streaming query
      inputData.addData(inputRange)
      failAfter(STREAMING_TIMEOUT) {
        query.processAllAvailable()
      }
    } finally {
      if (query != null) {
        // Terminating streaming query
        query.stop()
        deleteCheckpointDir(checkpointDir)
      }
    }
  }

  def deleteCheckpointDir(fileStr: String): Unit = {
    val file = new File(fileStr)
    if (file != null) {
      JavaUtils.deleteRecursively(file)
    }
  }

  def createCheckpointDir(root: String = System.getProperty("java.io.tmpdir"),
                          namePrefix: String = "spark"): File = {

    var attempts = 0
    val maxAttempts = MAX_DIR_CREATION_ATTEMPTS
    var dir: File = null
    while (dir == null) {
      attempts += 1
      if (attempts > maxAttempts) {
        throw new IOException("Failed to create a temp directory (under " + root + ") after " +
          maxAttempts + " attempts!")
      }
      try {
        dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString)
        if (dir.exists() || !dir.mkdirs()) {
          dir = null
        }
      } catch { case _: SecurityException => dir = null }
    }
    dir.getCanonicalFile
  }

}

object StreamingTestHelper extends TestHelper with SpanSugar  {

  val MAX_DIR_CREATION_ATTEMPTS = 10
  val STREAMING_TIMEOUT = 60.seconds

}
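The MemoryStream-driven pattern above also works without the HiveAcid dependency. Below is a minimal, self-contained sketch that round-trips data through the built-in memory sink and reads it back as a table; all names are illustrative and not taken from the project above.

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}

object MemoryStreamRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("memory-stream-roundtrip").getOrCreate()
    import spark.implicits._
    implicit val sqlCtx: SQLContext = spark.sqlContext // MemoryStream needs an implicit SQLContext

    val input = MemoryStream[Int]
    val query: StreamingQuery = input.toDS()
      .map(i => (i, i * 10))
      .toDF("id", "value")
      .writeStream
      .format("memory")          // results land in an in-memory temp table
      .queryName("roundtrip")    // the temp table is named after the query
      .outputMode(OutputMode.Append)
      .start()

    input.addData(1 to 5)
    query.processAllAvailable() // block until everything added so far has been processed

    spark.table("roundtrip").show() // verify the five rows arrived
    query.stop()
    spark.stop()
  }
}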