org.apache.spark.rpc.ThreadSafeRpcEndpoint Scala Examples
Example 1
Source File: StateStoreCoordinator.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.execution.streaming.state

import scala.collection.mutable

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.util.RpcUtils

/**
 * RPC endpoint that records, per state store id, the executor location that most recently
 * reported the store as active, and answers queries about those records.
 *
 * NOTE(review): `instances` is a plain mutable map with no locking in this class; this
 * presumably relies on `ThreadSafeRpcEndpoint` delivering messages one at a time - confirm.
 */
private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
  extends ThreadSafeRpcEndpoint with Logging {

  /** Last reported location of each state store, keyed by store id. */
  private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]

  /** One-way messages: a state store instance is reported as active on an executor. */
  override def receive: PartialFunction[Any, Unit] = {
    case ReportActiveInstance(id, host, executorId) =>
      logDebug(s"Reported state store $id is active at $executorId")
      instances.put(id, ExecutorCacheTaskLocation(host, executorId))
  }

  /** Request/response messages; every branch answers through `context.reply`. */
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case VerifyIfInstanceActive(id, execId) =>
      // True only when a location is recorded for `id` and it names the same executor.
      val isActive = instances.get(id).exists(_.executorId == execId)
      logDebug(s"Verified that state store $id is active: $isActive")
      context.reply(isActive)

    case GetLocation(id) =>
      // Replies with the string form of the recorded location, or None when nothing is recorded.
      val location = instances.get(id).map(_.toString)
      logDebug(s"Got location of the state store $id: $location")
      context.reply(location)

    case DeactivateInstances(checkpointLocation) =>
      // Collect the matching ids first, then drop them from the map in one step.
      val staleIds = instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq
      instances --= staleIds
      logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " +
        staleIds.mkString(", "))
      context.reply(true)

    case StopCoordinator =>
      // Stop before replying to ensure that endpoint name has been deregistered
      stop()
      logInfo("StateStoreCoordinator stopped")
      context.reply(true)
  }
}
Example 2
Source File: LocalSchedulerBackend.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.scheduler.local import import import java.nio.ByteBuffer import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.internal.Logging import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() def getUserClasspath(conf: SparkConf): Seq[URL] = { val userClassPathStr = conf.getOption("spark.executor.extraClassPath") File(_).toURI.toURL) } launcherBackend.connect() override def start() { val rpcEnv = SparkEnv.get.rpcEnv val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) localEndpoint = rpcEnv.setupEndpoint("LocalSchedulerBackendEndpoint", executorEndpoint) System.currentTimeMillis, executorEndpoint.localExecutorId, new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) launcherBackend.setAppId(appId) launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def stop() { stop(SparkAppHandle.State.FINISHED) } override def reviveOffers() { localEndpoint.send(ReviveOffers) } override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId private def 
stop(finalState: SparkAppHandle.State): Unit = { localEndpoint.ask(StopExecutor) try { launcherBackend.setState(finalState) } finally { launcherBackend.close() } } }
Example 3
Source File: BlockManagerSlaveEndpoint.scala From drizzle-spark with Apache License 2.0
// NOTE(review): the package name and the target of one import were missing from the extracted
// text. They are restored here from the `private[storage]` qualifier and the block-manager
// message types matched below - confirm against the upstream file.
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}

/**
 * RPC endpoint that applies block-management requests to the given `BlockManager`.
 *
 * Removal requests are executed through `doAsync` on a dedicated thread pool and answered
 * when the work completes; status, matching-id and thread-dump requests are answered inline.
 *
 * @param rpcEnv the RPC environment this endpoint is registered with
 * @param blockManager the block manager the requests operate on
 * @param mapOutputTracker tracker notified when a shuffle is removed; may be null
 */
private[storage] class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        // The tracker is optional here, so only notify it when one was supplied.
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Runs `body` on the async thread pool and reports the outcome to the caller.
   *
   * On success the result is sent with `context.reply`; on failure the error is logged and
   * forwarded with `context.sendFailure`.
   *
   * @param actionMessage human-readable description of the action, used in log messages
   * @param context the call context to answer
   * @param body the work to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace `onSuccess` / `onFailure`, deprecated since Scala 2.12.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.senderAddress}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  /** Shuts the async thread pool down when the endpoint stops. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 4
Source File: StateStoreCoordinator.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.execution.streaming.state

import java.util.UUID

import scala.collection.mutable

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.util.RpcUtils

/**
 * RPC endpoint that records, per state store provider id, the executor location that most
 * recently reported the provider as active, and answers queries about those records.
 *
 * NOTE(review): `instances` is a plain mutable map with no locking in this class; this
 * presumably relies on `ThreadSafeRpcEndpoint` delivering messages one at a time - confirm.
 */
private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
  extends ThreadSafeRpcEndpoint with Logging {

  /** Last reported location of each state store provider, keyed by provider id. */
  private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation]

  /** One-way messages: a state store instance is reported as active on an executor. */
  override def receive: PartialFunction[Any, Unit] = {
    case ReportActiveInstance(id, host, executorId) =>
      logDebug(s"Reported state store $id is active at $executorId")
      instances.put(id, ExecutorCacheTaskLocation(host, executorId))
  }

  /** Request/response messages; every branch answers through `context.reply`. */
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case VerifyIfInstanceActive(id, execId) =>
      // True only when a location is recorded for `id` and it names the same executor.
      val response = instances.get(id).exists(_.executorId == execId)
      logDebug(s"Verified that state store $id is active: $response")
      context.reply(response)

    case GetLocation(id) =>
      // Replies with the string form of the recorded location, or None when nothing is recorded.
      val location = instances.get(id).map(_.toString)
      logDebug(s"Got location of the state store $id: $location")
      context.reply(location)

    case DeactivateInstances(runId) =>
      // Instances are selected by the query run id carried in their provider id.
      val storeIdsToRemove = instances.keys.filter(_.queryRunId == runId).toSeq
      instances --= storeIdsToRemove
      // The message names the run id; it previously called this value a checkpoint location.
      logDebug(s"Deactivating instances related to query run $runId: " +
        storeIdsToRemove.mkString(", "))
      context.reply(true)

    case StopCoordinator =>
      // Stop before replying to ensure that endpoint name has been deregistered
      stop()
      logInfo("StateStoreCoordinator stopped")
      context.reply(true)
  }
}
Example 5
Source File: RPCContinuousShuffleReader.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.NextIterator override def getNext(): UnsafeRow = { var nextRow: UnsafeRow = null while (!finished && nextRow == null) { completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { case null => // Try again if the poll didn't wait long enough to get a real result. // But we should be getting at least an epoch marker every checkpoint interval. val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { case (flag, idx) if !flag => idx } logWarning( s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") // The completion service guarantees this future will be available immediately. case future => future.get() match { case ReceiverRow(writerId, r) => // Start reading the next element in the queue we just took from. completion.submit(completionTask(writerId)) nextRow = r case ReceiverEpochMarker(writerId) => // Don't read any more from this queue. If all the writers have sent epoch markers, // the epoch is over; otherwise we need to loop again to poll from the remaining // writers. writerEpochMarkersReceived(writerId) = true if (writerEpochMarkersReceived.forall(_ == true)) { finished = true } } } } nextRow } override def close(): Unit = { executor.shutdownNow() } } } }
Example 6
Source File: ContinuousRecordEndpoint.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.execution.streaming import org.apache.spark.SparkEnv import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset case class ContinuousRecordPartitionOffset(partitionId: Int, offset: Int) extends PartitionOffset case class GetRecord(offset: ContinuousRecordPartitionOffset) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetRecord(ContinuousRecordPartitionOffset(partitionId, offset)) => lock.synchronized { val bufOffset = offset - startOffsets(partitionId) val buf = buckets(partitionId) val record = if (buf.size <= bufOffset) None else Some(buf(bufOffset)) context.reply( } } }
Example 7
Source File: OapRpcManagerMaster.scala From OAP with Apache License 2.0
package org.apache.spark.sql.oap.rpc import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.sql.oap.listener.SparkListenerCustomInfoUpdate import org.apache.spark.sql.oap.rpc.OapMessages._ private[spark] class OapRpcManagerMaster(oapRpcManagerMasterEndpoint: OapRpcManagerMasterEndpoint) extends OapRpcManager with Logging { private def sendOneWayMessageToExecutors(message: OapMessage): Unit = { oapRpcManagerMasterEndpoint.rpcEndpointRefByExecutor.foreach { case (_, slaveEndpoint) => slaveEndpoint.send(message) } } override private[spark] def send(message: OapMessage): Unit = { sendOneWayMessageToExecutors(message) } } private[spark] object OapRpcManagerMaster { val DRIVER_ENDPOINT_NAME = "OapRpcManagerMaster" } private[spark] class OapRpcManagerMasterEndpoint( override val rpcEnv: RpcEnv, listenerBus: LiveListenerBus) extends ThreadSafeRpcEndpoint with Logging { // Mapping from executor ID to RpcEndpointRef. 
private[rpc] val rpcEndpointRefByExecutor = new mutable.HashMap[String, RpcEndpointRef] override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterOapRpcManager(executorId, slaveEndpoint) => context.reply(handleRegistration(executorId, slaveEndpoint)) case _ => } override def receive: PartialFunction[Any, Unit] = { case heartbeat: Heartbeat => handleHeartbeat(heartbeat) case message: OapMessage => handleNormalOapMessage(message) case _ => } private def handleRegistration(executorId: String, ref: RpcEndpointRef): Boolean = { rpcEndpointRefByExecutor += ((executorId, ref)) true } private def handleNormalOapMessage(message: OapMessage) = message match { case _: Heartbeat => throw new IllegalArgumentException( "This is only to deal with non-heartbeat messages") case DummyMessage(id, someContent) => val c = this.getClass.getMethods logWarning(s"Dummy message received on Driver with id: $id, content: $someContent") case _ => } private def handleHeartbeat(heartbeat: Heartbeat) = heartbeat match { case FiberCacheHeartbeat(executorId, blockManagerId, content) =>, executorId, "OapFiberCacheHeartBeatMessager", content)) case FiberCacheMetricsHeartbeat(executorId, blockManagerId, content) =>, executorId, "FiberCacheManagerMessager", content)) case _ => } }
Example 8
Source File: OapRpcManagerSlave.scala From OAP with Apache License 2.0
package org.apache.spark.sql.oap.rpc import java.util.concurrent.TimeUnit import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.execution.datasources.oap.filecache.{CacheStats, FiberCacheManager} import org.apache.spark.sql.internal.oap.OapConf import org.apache.spark.sql.oap.adapter.RpcEndpointRefAdapter import org.apache.spark.sql.oap.rpc.OapMessages._ import import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class OapRpcManagerSlave( rpcEnv: RpcEnv, val driverEndpoint: RpcEndpointRef, executorId: String, blockManager: BlockManager, fiberCacheManager: FiberCacheManager, conf: SparkConf) extends OapRpcManager { // Send OapHeartbeatMessage to Driver timed private val oapHeartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") private val slaveEndpoint = rpcEnv.setupEndpoint( s"OapRpcManagerSlave_$executorId", new OapRpcManagerSlaveEndpoint(rpcEnv, fiberCacheManager)) initialize() startOapHeartbeater() protected def heartbeatMessages: Array[() => Heartbeat] = { Array( () => FiberCacheHeartbeat( executorId, blockManager.blockManagerId, fiberCacheManager.status()), () => FiberCacheMetricsHeartbeat(executorId, blockManager.blockManagerId, CacheStats.status(fiberCacheManager.cacheStats, conf))) } private def initialize() = { RpcEndpointRefAdapter.askSync[Boolean]( driverEndpoint, RegisterOapRpcManager(executorId, slaveEndpoint)) } override private[spark] def send(message: OapMessage): Unit = { driverEndpoint.send(message) } private[sql] def startOapHeartbeater(): Unit = { def reportHeartbeat(): Unit = { // OapRpcManagerSlave is created in SparkEnv. Before we start the heartbeat, we need make // sure the SparkEnv has been created and the block manager has been initialized. We check // blockManagerId as it will be set after initialization. 
if (blockManager.blockManagerId != null) { } } val intervalMs = conf.getTimeAsMs( OapConf.OAP_HEARTBEAT_INTERVAL.key, OapConf.OAP_HEARTBEAT_INTERVAL.defaultValue.get) // Wait a random interval so the heartbeats don't end up in sync val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] val heartbeatTask = new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions(reportHeartbeat()) } oapHeartbeater.scheduleAtFixedRate( heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } override private[spark] def stop(): Unit = { oapHeartbeater.shutdown() } } private[spark] class OapRpcManagerSlaveEndpoint( override val rpcEnv: RpcEnv, fiberCacheManager: FiberCacheManager) extends ThreadSafeRpcEndpoint with Logging { override def receive: PartialFunction[Any, Unit] = { case message: OapMessage => handleOapMessage(message) case _ => } private def handleOapMessage(message: OapMessage): Unit = message match { case CacheDrop(indexName) => fiberCacheManager.releaseIndexCache(indexName) case _ => } }
Example 9
Source File: BlockManagerSlaveEndpoint.scala From sparkoscope with Apache License 2.0
// NOTE(review): the package name and the target of one import were missing from the extracted
// text. They are restored here from the `private[storage]` qualifier and the block-manager
// message types matched below - confirm against the upstream file.
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}

/**
 * RPC endpoint that applies block-management requests to the given `BlockManager`.
 *
 * Removal requests are executed through `doAsync` on a dedicated thread pool and answered
 * when the work completes; status, matching-id and thread-dump requests are answered inline.
 *
 * @param rpcEnv the RPC environment this endpoint is registered with
 * @param blockManager the block manager the requests operate on
 * @param mapOutputTracker tracker notified when a shuffle is removed; may be null
 */
private[storage] class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        // The tracker is optional here, so only notify it when one was supplied.
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Runs `body` on the async thread pool and reports the outcome to the caller.
   *
   * On success the result is sent with `context.reply`; on failure the error is logged and
   * forwarded with `context.sendFailure`.
   *
   * @param actionMessage human-readable description of the action, used in log messages
   * @param context the call context to answer
   * @param body the work to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace `onSuccess` / `onFailure`, deprecated since Scala 2.12.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.senderAddress}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  /** Shuts the async thread pool down when the endpoint stops. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 10
Source File: BlockManagerSlaveEndpoint.scala From multi-tenancy-spark with Apache License 2.0
// NOTE(review): the package name and the target of one import were missing from the extracted
// text. They are restored here from the `private[storage]` qualifier and the block-manager
// message types matched below - confirm against the upstream file.
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.{MapOutputTracker, SparkEnv}

/**
 * RPC endpoint that applies block-management requests to the given `BlockManager`.
 *
 * Removal requests are executed through `doAsync` on a dedicated thread pool and answered
 * when the work completes; status, matching-id and thread-dump requests are answered inline.
 *
 * @param rpcEnv the RPC environment this endpoint is registered with
 * @param blockManager the block manager the requests operate on
 * @param mapOutputTracker tracker notified when a shuffle is removed; may be null
 */
private[storage] class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  // User name captured once at construction; used to look up the matching SparkEnv below.
  private val user = Utils.getCurrentUserName

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        // The tracker is optional here, so only notify it when one was supplied.
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get(user).shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Runs `body` on the async thread pool and reports the outcome to the caller.
   *
   * On success the result is sent with `context.reply`; on failure the error is logged and
   * forwarded with `context.sendFailure`.
   *
   * @param actionMessage human-readable description of the action, used in log messages
   * @param context the call context to answer
   * @param body the work to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace `onSuccess` / `onFailure`, deprecated since Scala 2.12.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.senderAddress}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  /** Shuts the async thread pool down when the endpoint stops. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 11
Source File: BlockManagerSlaveEndpoint.scala From Spark-2.3.1 with Apache License 2.0
// NOTE(review): the package name and the target of one import were missing from the extracted
// text. They are restored here from the `private[storage]` qualifier and the block-manager
// message types matched below - confirm against the upstream file.
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}

/**
 * Endpoint that carries out block-management requests against the supplied `BlockManager`.
 *
 * Removals go through `doAsync` and are answered once the background work finishes; status
 * lookups, thread dumps and block replication are answered directly in the handler.
 */
private[storage] class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean](s"removing block $blockId", context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int](s"removing RDD $rddId", context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean](s"removing shuffle $shuffleId", context) {
        // Notify the tracker only when one was supplied (it may be null).
        Option(mapOutputTracker).foreach(_.unregisterShuffle(shuffleId))
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int](s"removing broadcast $broadcastId", context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      val status = blockManager.getStatus(blockId)
      context.reply(status)

    case GetMatchingBlockIds(filter, _) =>
      val matching = blockManager.getMatchingBlockIds(filter)
      context.reply(matching)

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())

    case ReplicateBlock(blockId, replicas, maxReplicas) =>
      val replicated = blockManager.replicateBlock(blockId, replicas.toSet, maxReplicas)
      context.reply(replicated)
  }

  /**
   * Evaluates `body` on the async thread pool, then replies with its result or, if it threw,
   * logs the error and sends the failure back through `context`.
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val pending = Future {
      logDebug(actionMessage)
      body
    }
    pending.onComplete {
      case scala.util.Success(result) =>
        logDebug(s"Done $actionMessage, response is $result")
        context.reply(result)
        logDebug(s"Sent response: $result to ${context.senderAddress}")
      case scala.util.Failure(cause) =>
        logError(s"Error in $actionMessage", cause)
        context.sendFailure(cause)
    }
  }

  /** Stops the background pool together with the endpoint. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 12
Source File: BlockManagerSlaveEndpoint.scala From BigDatalog with Apache License 2.0
// NOTE(review): the package name and the target of one import were missing from the extracted
// text. They are restored here from the `private[storage]` qualifier and the block-manager
// message types matched below - confirm against the upstream file.
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}

/**
 * RPC endpoint that applies block-management requests to the given `BlockManager`.
 *
 * Removal requests are executed through `doAsync` on a dedicated thread pool and answered
 * when the work completes; status, matching-id and thread-dump requests are answered inline.
 *
 * @param rpcEnv the RPC environment this endpoint is registered with
 * @param blockManager the block manager the requests operate on
 * @param mapOutputTracker tracker notified when a shuffle is removed; may be null
 */
private[storage] class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        // The tracker is optional here, so only notify it when one was supplied.
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Runs `body` on the async thread pool and reports the outcome to the caller.
   *
   * On success the result is sent with `context.reply`; on failure the error is logged and
   * forwarded with `context.sendFailure`.
   *
   * @param actionMessage human-readable description of the action, used in log messages
   * @param context the call context to answer
   * @param body the work to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace `onSuccess` / `onFailure`, deprecated since Scala 2.12.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.senderAddress}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  /** Shuts the async thread pool down when the endpoint stops. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 13
Source File: TestMetricsRpcEndpoint.scala From spark-metrics with BSD 3-Clause "New" or "Revised" License
package org.apache.spark.groupon.metrics.util

import org.apache.spark.groupon.metrics._
import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv}

/**
 * RPC endpoint for tests: appends every received `MetricMessage` to `metricStore` so that
 * the received metric names and values can be inspected afterwards.
 */
class TestMetricsRpcEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

  /** Every metric message received so far, in arrival order. */
  val metricStore = new scala.collection.mutable.ArrayBuffer[MetricMessage]()

  override def receive: PartialFunction[Any, Unit] = {
    case msg: MetricMessage => {
      metricStore += msg
    }
  }

  /** Discards all stored messages. */
  def clear(): Unit = {
    metricStore.clear()
  }

  // NOTE(review): in the extracted text the two accessors below had lost the start of their
  // bodies (only "=> metricMessage.<field>).toSeq" remained). They are reconstructed as a map
  // over `metricStore`, the only collection of MetricMessage in this class - confirm upstream.

  /** Names of the stored metric messages, in arrival order. */
  def getMetricNames: Seq[String] = {
    metricStore.map(metricMessage => metricMessage.metricName).toSeq
  }

  /** Values of the stored metric messages, in arrival order. */
  def getMetricValues: Seq[AnyVal] = {
    metricStore.map(metricMessage => metricMessage.value).toSeq
  }
}