org.apache.spark.rpc.RpcEndpointRef Scala Examples

The following examples show how to use org.apache.spark.rpc.RpcEndpointRef. Each example is extracted from the open-source project and source file named in its header.
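Before looking at the project sources, it helps to see the basic pattern: an RpcEndpoint is registered with an RpcEnv, which returns an RpcEndpointRef that callers use to send one-way messages or ask for replies. The sketch below is a minimal illustration written for this page, not taken from any of the projects; the EchoEndpoint class, the "echo" endpoint name, and the demo package are assumptions, and it uses the Spark 2.x API (RpcEnv, RpcEndpointRef and friends are private[spark], so the code must live under the org.apache.spark package, as all of the examples below do).

// Minimal sketch (assumed names, Spark 2.x API): register an endpoint, then talk to it via its ref.
package org.apache.spark.demo

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}

// Hypothetical endpoint that echoes String messages back to the caller.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(s"echo: $msg")
  }
}

object RpcEndpointRefDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("rpc-endpoint-ref-demo")
    val rpcEnv = RpcEnv.create("demo", "localhost", 0, conf, new SecurityManager(conf))

    // setupEndpoint registers the endpoint and hands back the ref that remote callers use.
    val ref: RpcEndpointRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
    ref.send("fire-and-forget")              // one-way message, no reply expected
    val reply = ref.askSync[String]("hello") // blocking request/reply (Spark 2.x)
    println(reply)

    rpcEnv.shutdown()
    rpcEnv.awaitTermination()
  }
}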
Example 1
Source File: YarnRMClient.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.deploy.yarn

import scala.collection.JavaConverters._

import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.webapp.util.WebAppUtils

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils


private[spark] class YarnRMClient extends Logging {

  def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
    val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt)
    val yarnMaxAttempts = yarnConf.getInt(
      YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
    sparkMaxAttempts match {
      case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
      case None => yarnMaxAttempts
    }
  }

} 
Example 2
Source File: RpcUtils.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.util

import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}

private[spark] object RpcUtils {

  // Cap so that maxSizeInMB * 1024 * 1024 stays within Int range (matches Spark's own definition).
  private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024

  def maxMessageSizeBytes(conf: SparkConf): Int = {
    val maxSizeInMB = conf.getInt("spark.rpc.message.maxSize", 128)
    if (maxSizeInMB > MAX_MESSAGE_SIZE_IN_MB) {
      throw new IllegalArgumentException(
        s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB")
    }
    maxSizeInMB * 1024 * 1024
  }
} 
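As a quick usage note: the helper above converts the spark.rpc.message.maxSize setting (given in MB) into bytes and rejects values that would overflow an Int. The sketch below is a hypothetical caller written for this page; the MaxMessageSizeCheck object and the demo package are assumptions, and it sits under org.apache.spark because RpcUtils is private[spark].

// Hypothetical caller of RpcUtils.maxMessageSizeBytes; names and package are assumptions for this sketch.
package org.apache.spark.demo

import org.apache.spark.SparkConf
import org.apache.spark.util.RpcUtils

object MaxMessageSizeCheck {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set("spark.rpc.message.maxSize", "256")
    val limitBytes = RpcUtils.maxMessageSizeBytes(conf) // 256 * 1024 * 1024 bytes

    val payload = new Array[Byte](1024)
    require(payload.length <= limitBytes,
      s"payload of ${payload.length} bytes exceeds spark.rpc.message.maxSize ($limitBytes bytes)")
    println(s"RPC message limit: $limitBytes bytes")
  }
}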
Example 3
Source File: MasterWebUISuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 4
Source File: ReceiverInfo.scala    From iolap   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rpc.RpcEndpointRef


@DeveloperApi
case class ReceiverInfo(
    streamId: Int,
    name: String,
    private[streaming] val endpoint: RpcEndpointRef,
    active: Boolean,
    location: String,
    lastErrorMessage: String = "",
    lastError: String = "",
    lastErrorTime: Long = -1L
   ) {
} 
Example 5
Source File: InstantiateClass.scala    From spark1.52   with Apache License 2.0
package sparkDemo

import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.SparkEnv.{logDebug, logInfo}
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.RpcUtils




object InstantiateClass extends App with Logging {
  val conf: SparkConf = new SparkConf()
  def instantiateClass[T](className: String): T = {
    logInfo(s"className: ${className}")
    val cls = classForName(className)
    logInfo(s"cls: ${cls}")
    // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
    // SparkConf, then one taking no arguments
    try {
      // classOf[SparkConf] is Scala's equivalent of the Java class literal SparkConf.class
      logInfo(s"classOf[SparkConf]: ${classOf[SparkConf]}")

      val tset = cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
        .newInstance(conf, new java.lang.Boolean(true))
        // asInstanceOf casts the reflectively created instance to T
        .asInstanceOf[T]
      logInfo(s"asInstanceOf[T]: ${tset.toString}")
      cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
        .newInstance(conf, new java.lang.Boolean(true))
        // asInstanceOf casts the reflectively created instance to T
        .asInstanceOf[T]
    } catch {
      case _: NoSuchMethodException =>
        try {
          logInfo(s"asInstanceOf[T]: ${cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]}")
          cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]

        } catch {
          case _: NoSuchMethodException =>
            logInfo(s"no-arg newInstance(): ${cls.getConstructor().newInstance()}")
            logInfo(s"asInstanceOf[T]: ${cls.getConstructor().newInstance().asInstanceOf[T]}")

            cls.getConstructor().newInstance().asInstanceOf[T]
        }
    }
  }

  def classForName(className: String): Class[_] = {

    val classLoader = getContextOrSparkClassLoader
    logInfo(s"classLoader: ${classLoader}")
    Class.forName(className, true, getContextOrSparkClassLoader)
    // scalastyle:on classforname
  }
  def getContextOrSparkClassLoader: ClassLoader = {
    // getContextClassLoader returns the context class loader of the current thread
    val contextClassLoader = Thread.currentThread().getContextClassLoader

    logInfo(s"contextClassLoader: ${contextClassLoader}")
    // Fall back to the loader that loaded Spark's classes when no context class loader is set
    Option(contextClassLoader).getOrElse(getSparkClassLoader)
  }
  def getSparkClassLoader: ClassLoader ={
    logInfo(s"getClass.getClassLoader: ${ getClass.getClassLoader}")
    getClass.getClassLoader
  }

  def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {
    instantiateClass[T](conf.get(propertyName, defaultClassName))
  }
  val serializer = instantiateClassFromConf[Serializer](
    "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
  logInfo(s"Using serializer: ${serializer.getClass}")
  println("====="+serializer.getClass)


} 
Example 6
Source File: ReceiverTrackingInfo.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.streaming.scheduler.ReceiverState._

private[streaming] case class ReceiverErrorInfo(
    lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)


private[streaming] case class ReceiverTrackingInfo(
    receiverId: Int,
    state: ReceiverState,
    scheduledExecutors: Option[Seq[String]],
    runningExecutor: Option[String],
    name: Option[String] = None,
    endpoint: Option[RpcEndpointRef] = None,
    errorInfo: Option[ReceiverErrorInfo] = None) {

  def toReceiverInfo: ReceiverInfo = ReceiverInfo(
    receiverId,
    name.getOrElse(""),
    state == ReceiverState.ACTIVE,
    location = runningExecutor.getOrElse(""),
    lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
    lastError = errorInfo.map(_.lastError).getOrElse(""),
    lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
  )
} 
Example 7
Source File: ReceiverInfo.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rpc.RpcEndpointRef


@DeveloperApi
case class ReceiverInfo(
    streamId: Int,
    name: String,
    active: Boolean,
    location: String,
    lastErrorMessage: String = "",
    lastError: String = "",
    lastErrorTime: Long = -1L
   ) {
} 
Example 8
Source File: ApplicationInfo.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.deploy.master

import java.util.Date

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.deploy.ApplicationDescription
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class ApplicationInfo(
  val startTime: Long,
  val id: String,
  val desc: ApplicationDescription,
  val submitDate: Date,
  val driver: RpcEndpointRef,
  defaultCores: Int)
    extends Serializable {
  // state is assigned from the ApplicationState enumeration
  @transient var state: ApplicationState.Value = _
  @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _
  @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _
  @transient var coresGranted: Int = _
  @transient var endTime: Long = _
  @transient var appSource: ApplicationSource = _

  // A cap on the number of executors this application can have at any given time.
  // By default, this is infinite. Only after the first allocation request is issued by the
  // application will this be set to a finite value. This is used for dynamic allocation.
  @transient private[master] var executorLimit: Int = _

  @transient private var nextExecutorId: Int = _

  init() // initialize the transient fields above

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }
  
  private[deploy] def getExecutorLimit: Int = executorLimit

  def duration: Long = {
    if (endTime != -1) {
      endTime - startTime
    } else {
      System.currentTimeMillis() - startTime
    }
  }

} 
Example 9
Source File: WorkerInfo.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String, // worker ID
    val host: String, // worker host
    val port: Int, // worker port
    val cores: Int, // number of CPU cores on the worker
    val memory: Int, // memory (in MB) available on the worker
    val endpoint: RpcEndpointRef,
    val webUiPort: Int,
    val publicAddress: String)
  extends Serializable {

  Utils.checkHost(host, "Expected hostname")
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init() // initialize the transient fields

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE // newly registered workers start in the ALIVE state
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def webUiAddress : String = {
    "http://" + this.publicAddress + ":" + this.webUiPort
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
Example 10
Source File: WorkerInfo.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String,
    val host: String,
    val port: Int,
    val cores: Int,
    val memory: Int,
    val endpoint: RpcEndpointRef,
    val webUiAddress: String)
  extends Serializable {

  Utils.checkHost(host, "Expected hostname")
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init()

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
Example 11
Source File: ReceiverTrackingInfo.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation}
import org.apache.spark.streaming.scheduler.ReceiverState._

private[streaming] case class ReceiverErrorInfo(
    lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)


private[streaming] case class ReceiverTrackingInfo(
    receiverId: Int,
    state: ReceiverState,
    scheduledLocations: Option[Seq[TaskLocation]],
    runningExecutor: Option[ExecutorCacheTaskLocation],
    name: Option[String] = None,
    endpoint: Option[RpcEndpointRef] = None,
    errorInfo: Option[ReceiverErrorInfo] = None) {

  def toReceiverInfo: ReceiverInfo = ReceiverInfo(
    receiverId,
    name.getOrElse(""),
    state == ReceiverState.ACTIVE,
    location = runningExecutor.map(_.host).getOrElse(""),
    executorId = runningExecutor.map(_.executorId).getOrElse(""),
    lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
    lastError = errorInfo.map(_.lastError).getOrElse(""),
    lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
  )
} 
Example 12
Source File: WorkerInfo.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String,
    val host: String,
    val port: Int,
    val cores: Int,
    val memory: Int,
    val endpoint: RpcEndpointRef,
    val webUiAddress: String)
  extends Serializable {

  Utils.checkHost(host)
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init()

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
Example 13
Source File: RpcUtils.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.util

import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}

private[spark] object RpcUtils {

  // Cap so that maxSizeInMB * 1024 * 1024 stays within Int range (matches Spark's own definition).
  private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024

  def maxMessageSizeBytes(conf: SparkConf): Int = {
    val maxSizeInMB = conf.getInt("spark.rpc.message.maxSize", 128)
    if (maxSizeInMB > MAX_MESSAGE_SIZE_IN_MB) {
      throw new IllegalArgumentException(
        s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB")
    }
    maxSizeInMB * 1024 * 1024
  }
} 
Example 14
Source File: MasterWebUISuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 15
Source File: ReceiverTrackingInfo.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation}
import org.apache.spark.streaming.scheduler.ReceiverState._

private[streaming] case class ReceiverErrorInfo(
    lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)


private[streaming] case class ReceiverTrackingInfo(
    receiverId: Int,
    state: ReceiverState,
    scheduledLocations: Option[Seq[TaskLocation]],
    runningExecutor: Option[ExecutorCacheTaskLocation],
    name: Option[String] = None,
    endpoint: Option[RpcEndpointRef] = None,
    errorInfo: Option[ReceiverErrorInfo] = None) {

  def toReceiverInfo: ReceiverInfo = ReceiverInfo(
    receiverId,
    name.getOrElse(""),
    state == ReceiverState.ACTIVE,
    location = runningExecutor.map(_.host).getOrElse(""),
    executorId = runningExecutor.map(_.executorId).getOrElse(""),
    lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
    lastError = errorInfo.map(_.lastError).getOrElse(""),
    lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
  )
} 
Example 16
Source File: ReceiverInfo.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rpc.RpcEndpointRef


@DeveloperApi
case class ReceiverInfo(
    streamId: Int,
    name: String,
    active: Boolean,
    location: String,
    executorId: String,
    lastErrorMessage: String = "",
    lastError: String = "",
    lastErrorTime: Long = -1L
   ) {
} 
Example 17
Source File: WorkerInfo.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String,
    val host: String,
    val port: Int,
    val cores: Int,
    val memory: Int,
    val endpoint: RpcEndpointRef,
    val webUiPort: Int,
    val publicAddress: String)
  extends Serializable {

  Utils.checkHost(host, "Expected hostname")
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init()

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def webUiAddress : String = {
    "http://" + this.publicAddress + ":" + this.webUiPort
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
Example 18
Source File: OapRpcManagerMaster.scala    From OAP   with Apache License 2.0
package org.apache.spark.sql.oap.rpc

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.sql.oap.listener.SparkListenerCustomInfoUpdate
import org.apache.spark.sql.oap.rpc.OapMessages._


private[spark] class OapRpcManagerMaster(oapRpcManagerMasterEndpoint: OapRpcManagerMasterEndpoint)
  extends OapRpcManager with Logging {

  private def sendOneWayMessageToExecutors(message: OapMessage): Unit = {
    oapRpcManagerMasterEndpoint.rpcEndpointRefByExecutor.foreach {
      case (_, slaveEndpoint) => slaveEndpoint.send(message)
    }
  }

  override private[spark] def send(message: OapMessage): Unit = {
    sendOneWayMessageToExecutors(message)
  }
}

private[spark] object OapRpcManagerMaster {
  val DRIVER_ENDPOINT_NAME = "OapRpcManagerMaster"
}

private[spark] class OapRpcManagerMasterEndpoint(
    override val rpcEnv: RpcEnv,
    listenerBus: LiveListenerBus)
  extends ThreadSafeRpcEndpoint with Logging {

  // Mapping from executor ID to RpcEndpointRef.
  private[rpc] val rpcEndpointRefByExecutor = new mutable.HashMap[String, RpcEndpointRef]

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RegisterOapRpcManager(executorId, slaveEndpoint) =>
      context.reply(handleRegistration(executorId, slaveEndpoint))
    case _ =>
  }

  override def receive: PartialFunction[Any, Unit] = {
    case heartbeat: Heartbeat => handleHeartbeat(heartbeat)
    case message: OapMessage => handleNormalOapMessage(message)
    case _ =>
  }

  private def handleRegistration(executorId: String, ref: RpcEndpointRef): Boolean = {
    rpcEndpointRefByExecutor += ((executorId, ref))
    true
  }

  private def handleNormalOapMessage(message: OapMessage) = message match {
    case _: Heartbeat => throw new IllegalArgumentException(
      "This is only to deal with non-heartbeat messages")
    case DummyMessage(id, someContent) =>
      val c = this.getClass.getMethods
      logWarning(s"Dummy message received on Driver with id: $id, content: $someContent")
    case _ =>
  }

  private def handleHeartbeat(heartbeat: Heartbeat) = heartbeat match {
    case FiberCacheHeartbeat(executorId, blockManagerId, content) =>
      listenerBus.post(SparkListenerCustomInfoUpdate(
        blockManagerId.host, executorId, "OapFiberCacheHeartBeatMessager", content))
    case FiberCacheMetricsHeartbeat(executorId, blockManagerId, content) =>
      listenerBus.post(SparkListenerCustomInfoUpdate(
        blockManagerId.host, executorId, "FiberCacheManagerMessager", content))
    case _ =>
  }
} 
Example 19
Source File: YarnRMClient.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.deploy.yarn

import java.util.{List => JList}

import scala.collection.JavaConverters._
import scala.util.Try

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.webapp.util.WebAppUtils

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils


private[spark] class YarnRMClient extends Logging {

  def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
    val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt)
    val yarnMaxAttempts = yarnConf.getInt(
      YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
    val retval: Int = sparkMaxAttempts match {
      case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
      case None => yarnMaxAttempts
    }

    retval
  }

} 
Example 20
Source File: ReceiverTrackingInfo.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation}
import org.apache.spark.streaming.scheduler.ReceiverState._

private[streaming] case class ReceiverErrorInfo(
    lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)


private[streaming] case class ReceiverTrackingInfo(
    receiverId: Int,
    state: ReceiverState,
    scheduledLocations: Option[Seq[TaskLocation]],
    runningExecutor: Option[ExecutorCacheTaskLocation],
    name: Option[String] = None,
    endpoint: Option[RpcEndpointRef] = None,
    errorInfo: Option[ReceiverErrorInfo] = None) {

  def toReceiverInfo: ReceiverInfo = ReceiverInfo(
    receiverId,
    name.getOrElse(""),
    state == ReceiverState.ACTIVE,
    location = runningExecutor.map(_.host).getOrElse(""),
    executorId = runningExecutor.map(_.executorId).getOrElse(""),
    lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
    lastError = errorInfo.map(_.lastError).getOrElse(""),
    lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
  )
} 
Example 21
Source File: WorkerInfo.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String,
    val host: String,
    val port: Int,
    val cores: Int,
    val memory: Int,
    val endpoint: RpcEndpointRef,
    val webUiAddress: String)
  extends Serializable {

  Utils.checkHost(host, "Expected hostname")
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init()

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
Example 22
Source File: LocalSchedulerBackend.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.scheduler.local

import java.io.File
import java.net.URL
import java.nio.ByteBuffer

import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.ExecutorInfo

private case class ReviveOffers()

private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)

private case class KillTask(taskId: Long, interruptThread: Boolean)

private case class StopExecutor()

private[spark] class LocalSchedulerBackend(
    conf: SparkConf,
    scheduler: TaskSchedulerImpl,
    val totalCores: Int)
  extends SchedulerBackend with ExecutorBackend with Logging {

  private val appId = "local-" + System.currentTimeMillis
  private var localEndpoint: RpcEndpointRef = null
  private val userClassPath = getUserClasspath(conf)
  private val listenerBus = scheduler.sc.listenerBus
  private val launcherBackend = new LauncherBackend() {
    override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
  }

  def getUserClasspath(conf: SparkConf): Seq[URL] = {
    val userClassPathStr = conf.getOption("spark.executor.extraClassPath")
    userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL)
  }

  launcherBackend.connect()

  override def start() {
    val rpcEnv = SparkEnv.get.rpcEnv
    val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores)
    localEndpoint = rpcEnv.setupEndpoint("LocalSchedulerBackendEndpoint", executorEndpoint)
    listenerBus.post(SparkListenerExecutorAdded(
      System.currentTimeMillis,
      executorEndpoint.localExecutorId,
      new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty)))
    launcherBackend.setAppId(appId)
    launcherBackend.setState(SparkAppHandle.State.RUNNING)
  }

  override def stop() {
    stop(SparkAppHandle.State.FINISHED)
  }

  override def reviveOffers() {
    localEndpoint.send(ReviveOffers)
  }

  override def defaultParallelism(): Int =
    scheduler.conf.getInt("spark.default.parallelism", totalCores)

  override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
    localEndpoint.send(KillTask(taskId, interruptThread))
  }

  override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
    localEndpoint.send(StatusUpdate(taskId, state, serializedData))
  }

  override def applicationId(): String = appId

  private def stop(finalState: SparkAppHandle.State): Unit = {
    localEndpoint.ask(StopExecutor)
    try {
      launcherBackend.setState(finalState)
    } finally {
      launcherBackend.close()
    }
  }

} 
Example 23
Source File: RpcUtils.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.util

import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}

private[spark] object RpcUtils {

  // Cap so that maxSizeInMB * 1024 * 1024 stays within Int range (matches Spark's own definition).
  private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024

  def maxMessageSizeBytes(conf: SparkConf): Int = {
    val maxSizeInMB = conf.getInt("spark.rpc.message.maxSize", 128)
    if (maxSizeInMB > MAX_MESSAGE_SIZE_IN_MB) {
      throw new IllegalArgumentException(
        s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB")
    }
    maxSizeInMB * 1024 * 1024
  }
} 
Example 24
Source File: MasterWebUISuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 25
Source File: StateStoreCoordinator.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming.state

import java.util.UUID

import scala.collection.mutable

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.util.RpcUtils


private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
    extends ThreadSafeRpcEndpoint with Logging {
  private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation]

  override def receive: PartialFunction[Any, Unit] = {
    case ReportActiveInstance(id, host, executorId) =>
      logDebug(s"Reported state store $id is active at $executorId")
      instances.put(id, ExecutorCacheTaskLocation(host, executorId))
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case VerifyIfInstanceActive(id, execId) =>
      val response = instances.get(id) match {
        case Some(location) => location.executorId == execId
        case None => false
      }
      logDebug(s"Verified that state store $id is active: $response")
      context.reply(response)

    case GetLocation(id) =>
      val executorId = instances.get(id).map(_.toString)
      logDebug(s"Got location of the state store $id: $executorId")
      context.reply(executorId)

    case DeactivateInstances(runId) =>
      val storeIdsToRemove =
        instances.keys.filter(_.queryRunId == runId).toSeq
      instances --= storeIdsToRemove
      logDebug(s"Deactivating instances related to checkpoint location $runId: " +
        storeIdsToRemove.mkString(", "))
      context.reply(true)

    case StopCoordinator =>
      stop() // Stop before replying to ensure that endpoint name has been deregistered
      logInfo("StateStoreCoordinator stopped")
      context.reply(true)
  }
} 
Example 26
Source File: RPCContinuousShuffleWriter.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming.continuous.shuffle

import scala.concurrent.Future
import scala.concurrent.duration.Duration

import org.apache.spark.Partitioner
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.util.ThreadUtils


class RPCContinuousShuffleWriter(
    writerId: Int,
    outputPartitioner: Partitioner,
    endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter {

  if (outputPartitioner.numPartitions != 1) {
    throw new IllegalArgumentException("multiple readers not yet supported")
  }

  if (outputPartitioner.numPartitions != endpoints.length) {
    throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " +
      s"not match endpoint count ${endpoints.length}")
  }

  def write(epoch: Iterator[UnsafeRow]): Unit = {
    while (epoch.hasNext) {
      val row = epoch.next()
      endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row))
    }

    val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq
    implicit val ec = ThreadUtils.sameThread
    ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf)
  }
} 
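The write loop above shows the two request/reply styles an RpcEndpointRef offers: askSync blocks on each row, while ask returns a Future, so the epoch markers can be fired at every endpoint and awaited together. Below is a minimal sketch of that fan-out pattern written for this page; the RefFanOut object, the broadcastAndWait helper, and the demo package are assumptions, and it lives under org.apache.spark because these APIs are private[spark].

// Hypothetical helper contrasting ask (async, Future) with askSync (blocking) on RpcEndpointRef.
package org.apache.spark.demo

import scala.concurrent.Future
import scala.concurrent.duration.Duration

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.ThreadUtils

object RefFanOut {
  def broadcastAndWait(refs: Seq[RpcEndpointRef], message: Any): Unit = {
    // Blocking style: one round trip at a time, in order.
    refs.foreach(_.askSync[Unit](message))

    // Asynchronous style: fire all requests, then await the replies together.
    val futures: Seq[Future[Unit]] = refs.map(_.ask[Unit](message))
    implicit val ec = ThreadUtils.sameThread
    ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf)
  }
}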
Example 27
Source File: StateStoreCoordinator.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.streaming.state

import scala.collection.mutable

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.util.RpcUtils


private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
    extends ThreadSafeRpcEndpoint with Logging {
  private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]

  override def receive: PartialFunction[Any, Unit] = {
    case ReportActiveInstance(id, host, executorId) =>
      logDebug(s"Reported state store $id is active at $executorId")
      instances.put(id, ExecutorCacheTaskLocation(host, executorId))
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case VerifyIfInstanceActive(id, execId) =>
      val response = instances.get(id) match {
        case Some(location) => location.executorId == execId
        case None => false
      }
      logDebug(s"Verified that state store $id is active: $response")
      context.reply(response)

    case GetLocation(id) =>
      val executorId = instances.get(id).map(_.toString)
      logDebug(s"Got location of the state store $id: $executorId")
      context.reply(executorId)

    case DeactivateInstances(checkpointLocation) =>
      val storeIdsToRemove =
        instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq
      instances --= storeIdsToRemove
      logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " +
        storeIdsToRemove.mkString(", "))
      context.reply(true)

    case StopCoordinator =>
      stop() // Stop before replying to ensure that endpoint name has been deregistered
      logInfo("StateStoreCoordinator stopped")
      context.reply(true)
  }
} 
Example 28
Source File: OapMessages.scala    From OAP   with Apache License 2.0
package org.apache.spark.sql.oap.rpc

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.storage.BlockManagerId

private[spark] object OapMessages {

  sealed trait OapMessage extends Serializable

  sealed trait ToOapRpcManagerSlave extends OapMessage
  sealed trait ToOapRpcManagerMaster extends OapMessage
  sealed trait Heartbeat extends ToOapRpcManagerMaster

  
  case class RegisterOapRpcManager(
      executorId: String, oapRpcManagerEndpoint: RpcEndpointRef) extends ToOapRpcManagerMaster
  case class DummyHeartbeat(someContent: String) extends Heartbeat
  case class FiberCacheHeartbeat(
      executorId: String, blockManagerId: BlockManagerId, content: String) extends Heartbeat
  case class FiberCacheMetricsHeartbeat(
      executorId: String, blockManagerId: BlockManagerId, content: String) extends Heartbeat
} 
Example 29
Source File: OapRpcManagerSlave.scala    From OAP   with Apache License 2.0
package org.apache.spark.sql.oap.rpc

import java.util.concurrent.TimeUnit

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.execution.datasources.oap.filecache.{CacheStats, FiberCacheManager}
import org.apache.spark.sql.internal.oap.OapConf
import org.apache.spark.sql.oap.adapter.RpcEndpointRefAdapter
import org.apache.spark.sql.oap.rpc.OapMessages._
import org.apache.spark.storage.BlockManager
import org.apache.spark.util.{ThreadUtils, Utils}


private[spark] class OapRpcManagerSlave(
    rpcEnv: RpcEnv,
    val driverEndpoint: RpcEndpointRef,
    executorId: String,
    blockManager: BlockManager,
    fiberCacheManager: FiberCacheManager,
    conf: SparkConf) extends OapRpcManager {

  // Periodically sends heartbeat messages to the driver
  private val oapHeartbeater =
    ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater")

  private val slaveEndpoint = rpcEnv.setupEndpoint(
    s"OapRpcManagerSlave_$executorId", new OapRpcManagerSlaveEndpoint(rpcEnv, fiberCacheManager))

  initialize()
  startOapHeartbeater()

  protected def heartbeatMessages: Array[() => Heartbeat] = {
    Array(
      () => FiberCacheHeartbeat(
        executorId, blockManager.blockManagerId, fiberCacheManager.status()),
      () => FiberCacheMetricsHeartbeat(executorId, blockManager.blockManagerId,
        CacheStats.status(fiberCacheManager.cacheStats, conf)))
  }

  private def initialize() = {
    RpcEndpointRefAdapter.askSync[Boolean](
      driverEndpoint, RegisterOapRpcManager(executorId, slaveEndpoint))
  }

  override private[spark] def send(message: OapMessage): Unit = {
    driverEndpoint.send(message)
  }

  private[sql] def startOapHeartbeater(): Unit = {

    def reportHeartbeat(): Unit = {
      // OapRpcManagerSlave is created in SparkEnv. Before we start the heartbeat, we need to
      // make sure the SparkEnv has been created and the block manager has been initialized. We
      // check blockManagerId as it will be set after initialization.
      if (blockManager.blockManagerId != null) {
        heartbeatMessages.map(_.apply()).foreach(send)
      }
    }

    val intervalMs = conf.getTimeAsMs(
      OapConf.OAP_HEARTBEAT_INTERVAL.key, OapConf.OAP_HEARTBEAT_INTERVAL.defaultValue.get)

    // Wait a random interval so the heartbeats don't end up in sync
    val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int]

    val heartbeatTask = new Runnable() {
      override def run(): Unit = Utils.logUncaughtExceptions(reportHeartbeat())
    }
    oapHeartbeater.scheduleAtFixedRate(
      heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
  }

  override private[spark] def stop(): Unit = {
    oapHeartbeater.shutdown()
  }
}

private[spark] class OapRpcManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv, fiberCacheManager: FiberCacheManager)
  extends ThreadSafeRpcEndpoint with Logging {

  override def receive: PartialFunction[Any, Unit] = {
    case message: OapMessage => handleOapMessage(message)
    case _ =>
  }

  private def handleOapMessage(message: OapMessage): Unit = message match {
    case CacheDrop(indexName) => fiberCacheManager.releaseIndexCache(indexName)
    case _ =>
  }
} 
Example 30
Source File: ReceiverTrackingInfo.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation}
import org.apache.spark.streaming.scheduler.ReceiverState._

private[streaming] case class ReceiverErrorInfo(
    lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)


private[streaming] case class ReceiverTrackingInfo(
    receiverId: Int,
    state: ReceiverState,
    scheduledLocations: Option[Seq[TaskLocation]],
    runningExecutor: Option[ExecutorCacheTaskLocation],
    name: Option[String] = None,
    endpoint: Option[RpcEndpointRef] = None,
    errorInfo: Option[ReceiverErrorInfo] = None) {

  def toReceiverInfo: ReceiverInfo = ReceiverInfo(
    receiverId,
    name.getOrElse(""),
    state == ReceiverState.ACTIVE,
    location = runningExecutor.map(_.host).getOrElse(""),
    executorId = runningExecutor.map(_.executorId).getOrElse(""),
    lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
    lastError = errorInfo.map(_.lastError).getOrElse(""),
    lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
  )
} 
Example 31
Source File: WorkerInfo.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String,
    val host: String,
    val port: Int,
    val cores: Int,
    val memory: Int,
    val endpoint: RpcEndpointRef,
    val webUiAddress: String)
  extends Serializable {

  Utils.checkHost(host, "Expected hostname")
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init()

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
Example 32
Source File: RpcUtils.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.util

import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}

private[spark] object RpcUtils {

  // Cap so that maxSizeInMB * 1024 * 1024 stays within Int range (matches Spark's own definition).
  private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024

  def maxMessageSizeBytes(conf: SparkConf): Int = {
    val maxSizeInMB = conf.getInt("spark.rpc.message.maxSize", 128)
    if (maxSizeInMB > MAX_MESSAGE_SIZE_IN_MB) {
      throw new IllegalArgumentException(
        s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB")
    }
    maxSizeInMB * 1024 * 1024
  }
} 
Example 33
Source File: MasterWebUISuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 34
Source File: ReceiverTrackingInfo.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation}
import org.apache.spark.streaming.scheduler.ReceiverState._

private[streaming] case class ReceiverErrorInfo(
    lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)


private[streaming] case class ReceiverTrackingInfo(
    receiverId: Int,
    state: ReceiverState,
    scheduledLocations: Option[Seq[TaskLocation]],
    runningExecutor: Option[ExecutorCacheTaskLocation],
    name: Option[String] = None,
    endpoint: Option[RpcEndpointRef] = None,
    errorInfo: Option[ReceiverErrorInfo] = None) {

  def toReceiverInfo: ReceiverInfo = ReceiverInfo(
    receiverId,
    name.getOrElse(""),
    state == ReceiverState.ACTIVE,
    location = runningExecutor.map(_.host).getOrElse(""),
    executorId = runningExecutor.map(_.executorId).getOrElse(""),
    lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
    lastError = errorInfo.map(_.lastError).getOrElse(""),
    lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
  )
}