org.apache.spark.sql.streaming.GroupState Scala Examples

The following examples show how to use org.apache.spark.sql.streaming.GroupState. The source file, originating project, and license are noted above each example.
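Before the project examples, a minimal sketch of the API: GroupState is the per-key state handle passed to the function you supply to Dataset.groupByKey(...).mapGroupsWithState(...). The Click/ClickCount types, the socket source, and the running-count logic below are illustrative assumptions, not taken from any of the projects that follow.

package sketches

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

// Illustrative types; not from any of the example projects below.
case class Click(userId: String)
case class ClickCount(userId: String, clicks: Long)

object GroupStateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("group-state-sketch").getOrCreate()
    import spark.implicits._

    // State-update function: receives all new events for a key plus the key's
    // GroupState, and returns one output row per key per trigger.
    def updateCounts(userId: String, events: Iterator[Click], state: GroupState[ClickCount]): ClickCount = {
      val previous = state.getOption.map(_.clicks).getOrElse(0L)
      val updated = ClickCount(userId, previous + events.size)
      state.update(updated) // persist the new state for the next trigger
      updated
    }

    // Each line read from the socket is treated as a user id.
    val clicks = spark.readStream
      .format("socket").option("host", "localhost").option("port", 9999).load()
      .as[String]
      .map(line => Click(line.trim))

    clicks
      .groupByKey(_.userId)
      .mapGroupsWithState[ClickCount, ClickCount](GroupStateTimeout.NoTimeout)(updateCounts)
      .writeStream
      .outputMode(OutputMode.Update())
      .format("console")
      .start()
      .awaitTermination()
  }
}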
Example 1
Source File: GroupsWithStateFunction.scala    From spark-structured-streaming-examples   with Apache License 2.0
package com.phylosoft.spark.learning.sql.streaming.operations.stateful

import com.phylosoft.spark.learning.sql.streaming.domain.Model.{Event, SessionInfo, SessionUpdate}
import org.apache.spark.sql.streaming.GroupState

trait GroupsWithStateFunction {

  private[stateful] val sessionUpdate = (sessionId: String,
                                         events: Iterator[Event],
                                         state: GroupState[SessionInfo]) => {
    // If timed out, then remove session and send final update
    if (state.hasTimedOut) {
      val finalUpdate =
        SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true)
      state.remove()
      finalUpdate
    } else {
      // Update start and end timestamps in session
      val timestamps = events.map(_.timestamp.getTime).toSeq
      val updatedSession = if (state.exists) {
        val oldSession = state.get
        SessionInfo(
          oldSession.numEvents + timestamps.size,
          oldSession.startTimestampMs,
          math.max(oldSession.endTimestampMs, timestamps.max))
      } else {
        SessionInfo(timestamps.size, timestamps.min, timestamps.max)
      }
      state.update(updatedSession)

      // Set timeout such that the session will expire if no data is received for 5 seconds
      state.setTimeoutDuration("5 seconds")
      SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false)
    }
  }

} 
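The trait above defines only the state-update function; the query that attaches it is not part of this snippet. Below is a hedged, self-contained sketch of how such a function is typically wired up with groupByKey and mapGroupsWithState under a processing-time timeout. The Event, SessionInfo and SessionUpdate case classes are reconstructed from the fields used above, and the socket source is an arbitrary choice; the project's real domain Model and input source may differ.

package sketches

import java.sql.Timestamp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

// Stand-ins reconstructed from the fields used in sessionUpdate above.
case class Event(sessionId: String, timestamp: Timestamp)
case class SessionInfo(numEvents: Int, startTimestampMs: Long, endTimestampMs: Long)
case class SessionUpdate(id: String, durationMs: Long, numEvents: Int, expired: Boolean)

object SessionizationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sessionization-sketch").getOrCreate()
    import spark.implicits._

    // Each text line arriving on the socket is treated as one event for that session id.
    val events = spark.readStream
      .format("socket").option("host", "localhost").option("port", 9999).load()
      .as[String]
      .map(id => Event(id.trim, new Timestamp(System.currentTimeMillis())))

    val sessions = events
      .groupByKey(_.sessionId)
      // ProcessingTimeTimeout is what makes hasTimedOut / setTimeoutDuration meaningful.
      .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) {
        (sessionId: String, evs: Iterator[Event], state: GroupState[SessionInfo]) =>
          if (state.hasTimedOut) {
            val s = state.get
            state.remove() // drop the expired session and emit a final update
            SessionUpdate(sessionId, s.endTimestampMs - s.startTimestampMs, s.numEvents, expired = true)
          } else {
            val ts = evs.map(_.timestamp.getTime).toSeq
            val old = state.getOption.getOrElse(SessionInfo(0, ts.min, ts.min))
            val updated = SessionInfo(old.numEvents + ts.size, old.startTimestampMs, math.max(old.endTimestampMs, ts.max))
            state.update(updated)
            state.setTimeoutDuration("5 seconds") // expire if no data arrives for 5 seconds
            SessionUpdate(sessionId, updated.endTimestampMs - updated.startTimestampMs, updated.numEvents, expired = false)
          }
      }

    sessions.writeStream
      .outputMode(OutputMode.Update())
      .format("console")
      .start()
      .awaitTermination()
  }
}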
Example 2
Source File: MultiStreamHandler.scala    From structured-streaming-application   with Apache License 2.0
package knolx.spark

import knolx.Config._
import knolx.KnolXLogger
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{Encoders, SparkSession}


case class CurrentPowerConsumption(kwh: Double)

case class PowerConsumptionStatus(numOfReadings: Long, total: Double, avg: Double, status: String) {
  def compute(newReadings: List[Double]): PowerConsumptionStatus = {
    val newTotal = newReadings.sum + total
    val newNumOfReadings = numOfReadings + newReadings.size
    val newAvg = newTotal / newNumOfReadings.toDouble

    PowerConsumptionStatus(newNumOfReadings, newTotal, newAvg, "ON")
  }
}

object MultiStreamHandler extends App with KnolXLogger {
  info("Creating Spark Session")
  val spark = SparkSession.builder().master(sparkMaster).appName(sparkAppName).getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  val updateStateFunc =
    (deviceId: String, newReadings: Iterator[(String, CurrentPowerConsumption)], state: GroupState[PowerConsumptionStatus]) => {
      val data = newReadings.toList.map { case(_, reading) => reading }.map(_.kwh)

      lazy val initialPowerConsumptionStatus = PowerConsumptionStatus(0L, 0D, 0D, "OFF")
      val currentState = state.getOption.fold(initialPowerConsumptionStatus.compute(data))(_.compute(data))

      val currentStatus =
        if(state.hasTimedOut) {
          // If we did not receive any reading for this device within the timeout, assume it is OFF.
          currentState.copy(status = "OFF")
        } else {
          state.setTimeoutDuration("10 seconds")
          currentState
        }

      state.update(currentStatus)
      (deviceId, currentStatus)
    }

  info("Creating Streaming DF...")
  val dataStream =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", topic)
      .option("failOnDataLoss", false)
      .option("includeTimestamp", true)
      .load()

  info("Writing data to Console...")
  import spark.implicits._

  implicit val currentPowerConsumptionEncoder = Encoders.kryo[CurrentPowerConsumption]
  implicit val powerConsumptionStatusEncoder = Encoders.kryo[PowerConsumptionStatus]

  val query =
    dataStream
      .select(col("key").cast(StringType).as("key"), col("value").cast(StringType).as("value"))
      .as[(String, String)]
      .map { case(deviceId, unit) =>
        (deviceId, CurrentPowerConsumption(Option(unit).fold(0D)(_.toDouble)))
      }
      .groupByKey { case(deviceId, _) => deviceId }
      .mapGroupsWithState[PowerConsumptionStatus, (String, PowerConsumptionStatus)](GroupStateTimeout.ProcessingTimeTimeout())(updateStateFunc)
      .toDF("deviceId", "current_status")
      .writeStream
      .format("console")
      .option("truncate", false)
      .outputMode(OutputMode.Update())
      .option("checkpointLocation", checkPointDir)
      .start()

  info("Waiting for the query to terminate...")
  query.awaitTermination()
  query.stop()
} 
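The state-transition arithmetic lives entirely in PowerConsumptionStatus.compute, so it can be checked without Spark. A minimal sketch, assuming the case classes above are on the classpath (the asserted values are illustrative):

object PowerConsumptionStatusCheck extends App {
  // Folds three readings into the initial OFF state and checks the running totals.
  val off = PowerConsumptionStatus(numOfReadings = 0L, total = 0D, avg = 0D, status = "OFF")
  val updated = off.compute(List(1.0, 2.0, 3.0))
  assert(updated.numOfReadings == 3L)
  assert(updated.total == 6.0)
  assert(updated.avg == 2.0)
  assert(updated.status == "ON")
  println(updated)
}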
Example 3
Source File: CountingInAStreamMapWithState.scala    From spark_training   with Apache License 2.0
package com.malaska.spark.training.streaming.structured

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

object CountingInAStreamMapWithState {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args:Array[String]): Unit = {
    val host = args(0)
    val port = args(1)
    val checkpointFolder = args(2)

    val isLocal = true

    val sparkSession = if (isLocal) {
      SparkSession.builder
        .master("local")
        .appName("my-spark-app")
        .config("spark.some.config.option", "config-value")
        .config("spark.driver.host","127.0.0.1")
        .config("spark.sql.parquet.compression.codec", "gzip")
        .master("local[3]")
        .getOrCreate()
    } else {
      SparkSession.builder
        .appName("my-spark-app")
        .config("spark.some.config.option", "config-value")
        .master("local[3]")
        .getOrCreate()
    }

    import sparkSession.implicits._

    val socketLines = sparkSession.readStream
      .format("socket")
      .option("host", host)
      .option("port", port)
      .load()

    val messageDs = socketLines.as[String].
      flatMap(line => line.toLowerCase().split(" ")).
      map(word => WordCountEvent(word, 1))

    // Generate running word count
    val wordCounts = messageDs.groupByKey(tuple => tuple.word).
      mapGroupsWithState[WordCountInMemory, WordCountReturn](GroupStateTimeout.ProcessingTimeTimeout) {

      case (word: String, events: Iterator[WordCountEvent], state: GroupState[WordCountInMemory]) =>
        var newCount = if (state.exists) state.get.countOfWord else 0

        events.foreach(tuple => {
          newCount += tuple.countOfWord
        })

        state.update(WordCountInMemory(newCount))

        WordCountReturn(word, newCount)
    }

    // Start running the query that prints the running counts to the console
    val query = wordCounts.writeStream
      .outputMode("update")
      .format("console")
      .start()

    query.awaitTermination()
  }
}

case class WordCountEvent(word: String, countOfWord: Int) extends Serializable

case class WordCountInMemory(countOfWord: Int) extends Serializable

case class WordCountReturn(word: String, countOfWord: Int) extends Serializable
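Note that this example passes GroupStateTimeout.ProcessingTimeTimeout but never calls state.setTimeoutDuration, so state.hasTimedOut is never true and the per-word state is kept forever. A hedged variant of the state function that actually uses the timeout is sketched below; it assumes the case classes above, an arbitrary 30-second window, and placement inside the same main method so it can plug into the existing mapGroupsWithState call.

import org.apache.spark.sql.streaming.GroupState

// Re-arms a processing-time timeout on every update and drops a word's state
// (emitting its final count) once no data has arrived within the window.
val wordCountWithExpiry =
  (word: String, events: Iterator[WordCountEvent], state: GroupState[WordCountInMemory]) => {
    if (state.hasTimedOut) {
      val lastCount = state.get.countOfWord
      state.remove() // discard state for idle words
      WordCountReturn(word, lastCount)
    } else {
      val newCount = state.getOption.map(_.countOfWord).getOrElse(0) + events.map(_.countOfWord).sum
      state.update(WordCountInMemory(newCount))
      state.setTimeoutDuration("30 seconds") // without this call, ProcessingTimeTimeout has no effect
      WordCountReturn(word, newCount)
    }
  }

// Used in place of the case block above:
//   messageDs.groupByKey(_.word)
//     .mapGroupsWithState[WordCountInMemory, WordCountReturn](GroupStateTimeout.ProcessingTimeTimeout)(wordCountWithExpiry)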
Example 4
Source File: MapGroupsWithState.scala    From Spark-Structured-Streaming-Examples   with Apache License 2.0
package mapGroupsWithState

import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{struct, to_json, _}
import _root_.log.LazyLogger
import org.apache.spark.sql.types.StringType
import spark.SparkHelper
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
import radio.{ArtistAggregationState, SimpleSongAggregation, SimpleSongAggregationKafka}

object MapGroupsWithState extends LazyLogger {
  private val spark = SparkHelper.getSparkSession()

  import spark.implicits._


  def updateArtistStateWithEvent(state: ArtistAggregationState, artistCount: SimpleSongAggregation): ArtistAggregationState = {
    log.warn("MapGroupsWithState - updateArtistStateWithEvent")
    if(state.artist == artistCount.artist) {
      ArtistAggregationState(state.artist, state.count + artistCount.count)
    } else {
      state
    }
  }

  def updateAcrossEvents(artist:String,
                         inputs: Iterator[SimpleSongAggregation],
                         oldState: GroupState[ArtistAggregationState]): ArtistAggregationState = {

    var state: ArtistAggregationState = if (oldState.exists)
      oldState.get
    else
      ArtistAggregationState(artist, 1L)

    // For every row, count the number of broadcasts per artist, instead of counting by artist, title and radio
    for (input <- inputs) {
      state = updateArtistStateWithEvent(state, input)
      oldState.update(state)
    }

    state
  }


  
  def write(ds: Dataset[SimpleSongAggregationKafka] ) = {
    ds.select($"radioCount.title", $"radioCount.artist", $"radioCount.radio", $"radioCount.count")
      .as[SimpleSongAggregation]
      .groupByKey(_.artist)
      .mapGroupsWithState(GroupStateTimeout.NoTimeout)(updateAcrossEvents) // GroupStateTimeout controls what happens to the state when no update arrives within a timeout; NoTimeout keeps state indefinitely here
      .writeStream
      .outputMode(OutputMode.Update())
      .format("console")
      .queryName("mapGroupsWithState - counting artist broadcast")
      .start()
  }
}
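A small design note on updateAcrossEvents above: oldState.update is called once per input row inside the loop; folding all inputs first and persisting the state once yields the same final state with a single update call. A hedged equivalent, assuming it sits in the same MapGroupsWithState object with the same imports:

  def updateAcrossEventsFolded(artist: String,
                               inputs: Iterator[SimpleSongAggregation],
                               oldState: GroupState[ArtistAggregationState]): ArtistAggregationState = {
    val initial = oldState.getOption.getOrElse(ArtistAggregationState(artist, 1L))
    // Fold every new row into the state, then persist it once.
    val folded = inputs.foldLeft(initial)(updateArtistStateWithEvent)
    oldState.update(folded)
    folded
  }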