org.mockito.Mockito.mock Scala Examples
The following examples show how to use org.mockito.Mockito.mock.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
Example 1
Source File: NaptimeModuleTest.scala From naptime with Apache License 2.0 | 8 votes |
package org.coursera.naptime

import java.util.Date
import javax.inject.Inject

import akka.stream.Materializer
import com.google.inject.Guice
import com.google.inject.Stage
import com.linkedin.data.schema.DataSchema
import com.linkedin.data.schema.DataSchemaUtil
import com.linkedin.data.schema.PrimitiveDataSchema
import com.linkedin.data.schema.RecordDataSchema
import org.coursera.naptime.model.KeyFormat
import org.coursera.naptime.resources.TopLevelCollectionResource
import org.coursera.naptime.router2.NaptimeRoutes
import org.junit.Test
import org.mockito.Mockito.mock
import org.scalatest.junit.AssertionsForJUnit
import play.api.libs.json.Json
import play.api.libs.json.OFormat

import scala.concurrent.ExecutionContext

/** Fixtures shared by the tests in the [[NaptimeModuleTest]] class. */
object NaptimeModuleTest {

  /** Model whose `createdAt` field has a type (`java.util.Date`) that needs a schema override. */
  case class User(name: String, createdAt: Date)

  object User {
    implicit val oFormat: OFormat[User] = Json.format[User]
  }

  /** Minimal top-level collection resource keyed by `String` and serving [[User]] values. */
  class MyResource(implicit val executionContext: ExecutionContext, val materializer: Materializer)
      extends TopLevelCollectionResource[String, User] {

    override implicit def resourceFormat: OFormat[User] = User.oFormat

    override def keyFormat: KeyFormat[KeyType] = KeyFormat.stringKeyFormat

    override def resourceName: String = "myResource"

    implicit val fields = Fields

    // The handler body is never executed by these tests; only the resource's shape matters.
    def get(id: String) = Nap.get(ctx => ???)
  }

  /**
   * Module under test: registers [[MyResource]], binds a Mockito mock as its instance, and
   * declares that `Date` should be represented by the primitive LONG schema.
   */
  object MyFakeModule extends NaptimeModule {
    override def configure(): Unit = {
      bindResource[MyResource]
      bind[MyResource].toInstance(mock(classOf[MyResource]))
      bindSchemaType[Date](
        DataSchemaUtil.dataSchemaTypeToPrimitiveDataSchema(DataSchema.Type.LONG))
    }
  }

  /** Injection target used to read the schema type overrides out of the injector. */
  class OverrideTypesHelper @Inject()(val schemaOverrideTypes: NaptimeModule.SchemaTypeOverrides)
}

/** Checks that schema type overrides bound in a module are collected and applied. */
class NaptimeModuleTest extends AssertionsForJUnit {
  import NaptimeModuleTest._

  /** Builds a new development-stage injector from the fake module plus `NaptimeModule`. */
  private def newInjector() =
    Guice.createInjector(Stage.DEVELOPMENT, MyFakeModule, NaptimeModule)

  @Test
  def checkInferredOverrides(): Unit = {
    val helper = newInjector().getInstance(classOf[OverrideTypesHelper])
    assert(helper.schemaOverrideTypes.size === 1)
    assert(helper.schemaOverrideTypes.contains("java.util.Date"))
  }

  @Test
  def checkComputedOverrides(): Unit = {
    val injector = newInjector()
    val helper = injector.getInstance(classOf[OverrideTypesHelper])
    val naptimeRoutes = injector.getInstance(classOf[NaptimeRoutes])
    assert(1 === naptimeRoutes.routerBuilders.size)

    val builder = naptimeRoutes.routerBuilders.head
    val userKeyed =
      builder.types.find(keyed => keyed.key == "org.coursera.naptime.NaptimeModuleTest.User").get
    assert(userKeyed.value.isInstanceOf[RecordDataSchema])
    val userRecord = userKeyed.value.asInstanceOf[RecordDataSchema]
    assert(2 === userRecord.getFields.size())

    // Re-read on every call: the fix-up below replaces the field's schema in place.
    def createdAtSchema = userRecord.getField("createdAt").getType.getDereferencedDataSchema

    // Before the fix-up, the Date field is a placeholder record flagged as un-inferable.
    val beforeFixup = createdAtSchema
    assert(beforeFixup.isInstanceOf[RecordDataSchema])
    assert(beforeFixup.asInstanceOf[RecordDataSchema].getDoc.contains("Unable to infer schema"))

    SchemaUtils.fixupInferredSchemas(userRecord, helper.schemaOverrideTypes)

    // After the fix-up, the override bound in MyFakeModule has been applied.
    assert(createdAtSchema.isInstanceOf[PrimitiveDataSchema])
  }
}
Example 2
Source File: NettyBlockTransferServiceSuite.scala From iolap with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

/**
 * Port-binding tests for [[NettyBlockTransferService]].
 *
 * Each test starts up to two services; `afterEach` closes whatever was started.
 */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /**
   * Closes both services and resets the fields.
   *
   * `service1` is closed even when closing `service0` throws, and `super.afterEach()` always
   * runs so that other stacked fixture traits still get their tear-down.
   */
  override def afterEach(): Unit = {
    try {
      try Option(service0).foreach(_.close())
      finally Option(service1).foreach(_.close())
    } finally {
      service0 = null
      service1 = null
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634
    service0 = createService(port)
    service0.port should be >= port
    // avoid testing equality in case of simultaneous tests
    service0.port should be <= (port + 10)
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  /**
   * Creates and initialises a transfer service configured for `port`.
   *
   * @param port value written to `spark.blockManager.port`
   * @return the service, initialised with a mocked [[BlockDataManager]]
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
}
Example 3
Source File: DiskBlockManagerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

/**
 * Tests for [[DiskBlockManager]] backed by two temporary root directories.
 *
 * The directories are created once per suite; a fresh `DiskBlockManager` is created before and
 * stopped after every test.
 */
class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  // Mocked BlockManager whose only stubbed member is `conf`.
  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)

  var diskBlockManager: DiskBlockManager = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  /** Removes both root directories, then always delegates to `super.afterAll()`. */
  override def afterAll(): Unit = {
    try {
      // Delete the second directory even when deleting the first one fails.
      try Utils.deleteRecursively(rootDir0)
      finally Utils.deleteRecursively(rootDir1)
    } finally {
      super.afterAll()
    }
  }

  override def beforeEach(): Unit = {
    super.beforeEach()
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach(): Unit = {
    try {
      diskBlockManager.stop()
    } finally {
      super.afterEach()
    }
  }

  test("basic block creation") {
    val blockId = TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  /**
   * Appends `numBytes` characters (code points 0 until `numBytes`) to `file`.
   *
   * The writer is closed even when a write fails.
   */
  def writeToFile(file: File, numBytes: Int): Unit = {
    val writer = new FileWriter(file, true)
    try {
      for (i <- 0 until numBytes) writer.write(i)
    } finally {
      writer.close()
    }
  }
}
Example 4
Source File: StagePageSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

/** Renders a [[StagePage]] against mocked collaborators and inspects the produced HTML. */
class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeKey = "spark.sql.unsafe.enabled"
    val needle = "peak execution memory"

    val enabledHtml = lowerCasedPage(new SparkConf(false).set(unsafeKey, "true"))
    assert(enabledHtml.contains(needle))

    // Disable unsafe and make sure it's not there
    val disabledHtml = lowerCasedPage(new SparkConf(false).set(unsafeKey, "false"))
    assert(!disabledHtml.contains(needle))

    // Avoid setting anything; it should be displayed by default
    val defaultHtml = lowerCasedPage(new SparkConf(false))
    assert(defaultHtml.contains(needle))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val pageHtml = lowerCasedPage(new SparkConf(false).set("spark.sql.unsafe.enabled", "true"))
    // verify min/25/50/75/max show task value not cumulative values
    assert(pageHtml.contains("<td>10.0 b</td>" * 5))
  }

  /** Renders the stage page for `conf` and returns its lower-cased string form. */
  private def lowerCasedPage(conf: SparkConf): String =
    renderStagePage(conf).toString().toLowerCase

  /**
   * Renders a stage page for stage 0 / attempt 0 after feeding two successful tasks through a
   * [[JobProgressListener]].
   *
   * @param conf configuration handed to the listeners and returned by the mocked tab
   * @return the nodes produced by `StagePage.render`
   */
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val progress = new JobProgressListener(conf)
    val graph = new RDDOperationGraphListener(conf)

    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(progress)
    when(tab.operationGraphListener).thenReturn(graph)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stage = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")

    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    for (taskId <- 1 to 2) {
      val task = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
      val perTaskPeak = 10
      // Third argument is the constant per-task figure; the fourth grows with the task id.
      val peakMemory = new AccumulableInfo(
        0,
        InternalAccumulator.PEAK_EXECUTION_MEMORY,
        Some(perTaskPeak.toString),
        (perTaskPeak * taskId).toString,
        true)
      task.accumulables += peakMemory
      progress.onStageSubmitted(SparkListenerStageSubmitted(stage))
      progress.onTaskStart(SparkListenerTaskStart(0, 0, task))
      task.markSuccessful()
      progress.onTaskEnd(SparkListenerTaskEnd(0, 0, "result", Success, task, TaskMetrics.empty))
    }
    progress.onStageCompleted(SparkListenerStageCompleted(stage))

    page.render(request)
  }
}
Example 5
Source File: LogPageSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

/** Exercises the private `getLog` method of [[LogPage]] through a mocked [[WorkerWebUI]]. */
class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")

    try {
      write(tmpOut, out)
      write(tmpErr, err)
      write(tmpOutBad, out)
      write(tmpErrBad, err)
      write(tmpRand, "1 6 4 5 2 7 8")

      // Get the logs. All log types other than "stderr" or "stdout" will be rejected
      val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
      val (stdout, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
      val (stderr, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
      val (error1, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
      val (error2, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
      // These files exist, but live outside the working directory
      val (error3, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
      val (error4, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)

      assert(stdout === out)
      assert(stderr === err)
      assert(error1.startsWith("Error: Log type must be one of "))
      assert(error2.startsWith("Error: Log type must be one of "))
      assert(error3.startsWith("Error: invalid log directory"))
      assert(error4.startsWith("Error: invalid log directory"))
    } finally {
      // Do not leave fixtures behind in the shared system temp directory.
      Seq(tmpOut, tmpErr, tmpRand, tmpOutBad, tmpErrBad).foreach(_.delete())
      // File.delete only removes an empty directory, so unrelated content is never lost.
      workDir.delete()
    }
  }

  /** Overwrites `f` with `s`, closing the writer even if the write fails. */
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }
}
Example 6
Source File: MasterWebUISuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.master.ui

import java.util.Date

import scala.io.Source
import scala.language.postfixOps

import org.json4s.jackson.JsonMethods._
import org.json4s.JsonAST.{JNothing, JString, JInt}
import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.MasterStateResponse
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.RpcEnv

/**
 * Queries the applications REST endpoint of a [[MasterWebUI]] whose [[MasterPage]] is mocked.
 *
 * The web UI is bound before and stopped after every test.
 */
class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter {

  val masterPage = mock(classOf[MasterPage])

  val master = {
    val conf = new SparkConf
    val securityMgr = new SecurityManager(conf)
    val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr)
    new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf)
  }

  val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage))

  before {
    masterWebUI.bind()
  }

  after {
    masterWebUI.stop()
  }

  test("list applications") {
    val worker = createWorkerInfo()
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue)
    activeApp.addExecutor(worker, 2)

    val workers = Array[WorkerInfo](worker)
    val activeApps = Array(activeApp)
    val completedApps = Array[ApplicationInfo]()
    val activeDrivers = Array[DriverInfo]()
    val completedDrivers = Array[DriverInfo]()
    val stateResponse = new MasterStateResponse(
      "host", 8080, None, workers, activeApps, completedApps,
      activeDrivers, completedDrivers, RecoveryState.ALIVE)

    when(masterPage.getMasterState).thenReturn(stateResponse)

    // Close the underlying stream once the body is read, even if reading fails.
    val source = Source.fromURL(
      s"http://localhost:${masterWebUI.boundPort}/api/v1/applications")
    val resultJson = try source.mkString finally source.close()
    val parsedJson = parse(resultJson)
    val firstApp = parsedJson(0)

    assert(firstApp \ "id" === JString(activeApp.id))
    assert(firstApp \ "name" === JString(activeApp.desc.name))
    // Matches the two cores granted through addExecutor above.
    assert(firstApp \ "coresGranted" === JInt(2))
    // NOTE(review): 4 and 1234 presumably come from createAppDesc() -- confirm in DeployTestUtils.
    assert(firstApp \ "maxCores" === JInt(4))
    assert(firstApp \ "memoryPerExecutorMB" === JInt(1234))
    assert(firstApp \ "coresPerExecutor" === JNothing)
  }
}
Example 7
Source File: NettyBlockTransferServiceSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

/**
 * Port-binding tests for [[NettyBlockTransferService]].
 *
 * Each test starts up to two services; `afterEach` closes whatever was started.
 */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /**
   * Closes both services and resets the fields.
   *
   * `service1` is closed even when closing `service0` throws, and `super.afterEach()` always
   * runs so that other stacked fixture traits still get their tear-down.
   */
  override def afterEach(): Unit = {
    try {
      try Option(service0).foreach(_.close())
      finally Option(service1).foreach(_.close())
    } finally {
      service0 = null
      service1 = null
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634
    service0 = createService(port)
    service0.port should be >= port
    // avoid testing equality in case of simultaneous tests
    service0.port should be <= (port + 10)
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  /**
   * Creates and initialises a transfer service configured for `port`.
   *
   * @param port value written to `spark.blockManager.port`
   * @return the service, initialised with a mocked [[BlockDataManager]]
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
}
Example 8
Source File: LogPageSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

/**
 * Exercises the private `getLog` method of [[LogPage]] through a mocked [[WorkerWebUI]] whose
 * worker is also mocked and returns a default `SparkConf`.
 */
class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")

    try {
      write(tmpOut, out)
      write(tmpErr, err)
      write(tmpOutBad, out)
      write(tmpErrBad, err)
      write(tmpRand, "1 6 4 5 2 7 8")

      // Get the logs. All log types other than "stderr" or "stdout" will be rejected
      val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
      val (stdout, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
      val (stderr, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
      val (error1, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
      val (error2, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
      // These files exist, but live outside the working directory
      val (error3, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
      val (error4, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)

      assert(stdout === out)
      assert(stderr === err)
      assert(error1.startsWith("Error: Log type must be one of "))
      assert(error2.startsWith("Error: Log type must be one of "))
      assert(error3.startsWith("Error: invalid log directory"))
      assert(error4.startsWith("Error: invalid log directory"))
    } finally {
      // Do not leave fixtures behind in the shared system temp directory.
      Seq(tmpOut, tmpErr, tmpRand, tmpOutBad, tmpErrBad).foreach(_.delete())
      // File.delete only removes an empty directory, so unrelated content is never lost.
      workDir.delete()
    }
  }

  /** Overwrites `f` with `s`, closing the writer even if the write fails. */
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }
}
Example 9
Source File: MasterWebUISuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL, URLEncoder}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}

/**
 * Posts kill requests to a [[MasterWebUI]] backed by a mocked [[Master]] and verifies that the
 * master (or its endpoint reference) received the corresponding call.
 *
 * The web UI is bound once for the whole suite.
 */
class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll(): Unit = {
    super.beforeAll()
    masterWebUI.bind()
  }

  /** Stops the web UI; `super.afterAll()` runs even when stopping fails. */
  override def afterAll(): Unit = {
    try {
      masterWebUI.stop()
    } finally {
      super.afterAll()
    }
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Reading the response code forces the request to be sent; the value itself is not checked.
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Reading the response code forces the request to be sent; the value itself is not checked.
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  /**
   * Serialises `data` as an `application/x-www-form-urlencoded` body.
   *
   * Names and values are URL-encoded as UTF-8, so reserved characters such as `&` or `=` cannot
   * corrupt the pair structure; plain alphanumeric input is emitted unchanged.
   *
   * @param data form fields to send
   * @return `name=value` pairs joined with `&`
   */
  private def convPostDataToString(data: Map[String, String]): String = {
    def encode(raw: String): String = URLEncoder.encode(raw, StandardCharsets.UTF_8.name())
    data.map { case (name, value) => s"${encode(name)}=${encode(value)}" }.mkString("&")
  }

  /**
   * Opens an HTTP connection to `url` and, when `body` is non-empty, writes it as form data.
   *
   * @param url    target URL
   * @param method HTTP method, e.g. "POST"
   * @param body   form-encoded payload; nothing is written when empty
   * @return the connection; the caller triggers the exchange (e.g. via `getResponseCode`)
   */
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      // Content-Length counts bytes, not characters, so encode first and measure the bytes.
      val payload = body.getBytes(StandardCharsets.UTF_8)
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(payload.length))
      val out = new DataOutputStream(conn.getOutputStream)
      try {
        out.write(payload)
      } finally {
        out.close()
      }
    }
    conn
  }
}
Example 10
Source File: NettyBlockTransferServiceSuite.scala From Spark-2.3.1 with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.network.BlockDataManager

/** Port-binding tests for [[NettyBlockTransferService]] using randomly chosen base ports. */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with Matchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /** Closes any service a test started, then always delegates to `super.afterEach()`. */
  override def afterEach(): Unit = {
    try {
      Option(service0).foreach { started =>
        started.close()
        service0 = null
      }
      Option(service1).foreach { started =>
        started.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val requested = pickRandomPort()
    service0 = createService(requested)
    verifyServicePort(expectedPort = requested, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val requested = pickRandomPort()
    service0 = createService(requested)
    verifyServicePort(expectedPort = requested, actualPort = service0.port)

    service1 = createService(service0.port)
    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  /** Picks a port in [17634, 27634) and logs it so a failing run can be diagnosed. */
  private def pickRandomPort(): Int = {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    port
  }

  /**
   * Asserts that `actualPort` lies in `[expectedPort, expectedPort + 100]`.
   *
   * Equality is not required because simultaneous tests may already hold the expected port;
   * 100 is the default value of `spark.port.maxRetries` under test.
   */
  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    val upperBound = expectedPort + 100
    actualPort should be >= expectedPort
    actualPort should be <= upperBound
  }

  /**
   * Creates a service asked to bind `port` on localhost, initialised with a mocked
   * [[BlockDataManager]].
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val sparkConf = new SparkConf().set("spark.app.id", s"test-${getClass.getName}")
    val transferService = new NettyBlockTransferService(
      sparkConf, new SecurityManager(sparkConf), "localhost", "localhost", port, 1)
    transferService.init(mock(classOf[BlockDataManager]))
    transferService
  }
}
Example 11
Source File: DiskBlockManagerSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

/**
 * Tests for [[DiskBlockManager]] backed by two temporary root directories.
 *
 * DiskBlockManager manages and maintains the mapping between logical blocks and the physical
 * blocks stored on disk. Generally, a logical block is mapped to a physical file whose name is
 * generated from its BlockId.
 */
class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  // Mocked BlockManager whose only stubbed member is `conf`.
  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)

  // DiskBlockManager creates and maintains the logical mapping between logical blocks and
  // physical on-disk locations. By default, one block is mapped to one file whose name is
  // given by its BlockId.
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  /** Removes both root directories, then always delegates to `super.afterAll()`. */
  override def afterAll(): Unit = {
    try {
      // Delete the second directory even when deleting the first one fails.
      try Utils.deleteRecursively(rootDir0)
      finally Utils.deleteRecursively(rootDir1)
    } finally {
      super.afterAll()
    }
  }

  override def beforeEach(): Unit = {
    super.beforeEach()
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach(): Unit = {
    try {
      diskBlockManager.stop()
    } finally {
      super.afterEach()
    }
  }

  test("basic block creation") {
    val blockId = TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  /**
   * Appends `numBytes` characters (code points 0 until `numBytes`) to `file`.
   *
   * The writer is closed even when a write fails.
   */
  def writeToFile(file: File, numBytes: Int): Unit = {
    val writer = new FileWriter(file, true)
    try {
      for (i <- 0 until numBytes) writer.write(i)
    } finally {
      writer.close()
    }
  }
}
Example 12
Source File: StagePageSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

/** Renders a [[StagePage]] against mocked collaborators and inspects the produced HTML. */
class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  // The peak execution memory value is displayed only when unsafe is enabled.
  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))

    // Disable unsafe and make sure it's not there
    val conf2 = new SparkConf(false).set(unsafeConf, "false")
    val html2 = renderStagePage(conf2).toString().toLowerCase
    assert(!html2.contains(targetString))

    // Avoid setting anything; it should be displayed by default
    val conf3 = new SparkConf(false)
    val html3 = renderStagePage(conf3).toString().toLowerCase
    assert(html3.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains("<td>10.0 b</td>" * 5))
  }

  /**
   * Renders a stage page for stage 0 / attempt 0 after feeding two successful tasks through a
   * [[JobProgressListener]].
   *
   * @param conf configuration handed to the listeners and returned by the mocked tab
   * @return the nodes produced by `StagePage.render`
   */
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach { taskId =>
      val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
      val peakExecutionMemory = 10
      // Third argument is the constant per-task figure; the fourth grows with the task id.
      taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY,
        Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true)
      jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
      jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
      taskInfo.markSuccessful()
      jobListener.onTaskEnd(
        SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }
}
Example 13
Source File: NettyBlockTransferServiceSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

/**
 * Checks which port a [[NettyBlockTransferService]] reports after it has been
 * initialised with a random (0) or an explicitly configured port.
 */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /** Closes whichever services a test created, then lets stacked traits clean up. */
  override def afterEach(): Unit = {
    try {
      if (service0 != null) {
        service0.close()
        service0 = null
      }
      if (service1 != null) {
        service1.close()
        service1 = null
      }
    } finally {
      // Always delegate, even when closing a service throws.
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634
    service0 = createService(port)
    // avoid testing equality in case of simultaneous tests
    service0.port should be >= port
    service0.port should be <= (port + 10)
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  /**
   * Creates and initialises a transfer service whose `spark.blockManager.port`
   * setting is `port`, backed by a mocked [[BlockDataManager]].
   *
   * @param port value written to `spark.blockManager.port`
   * @return the initialised service; the caller is responsible for closing it
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
}
Example 14
Source File: DiskBlockManagerSuite.scala From iolap with Apache License 2.0 | 5 votes |
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

/**
 * Exercises [[DiskBlockManager]] against two temporary root directories that are
 * created once per suite and removed when the suite finishes.
 */
class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll(): Unit = {
    // Remove the temp dirs first so a failure in a parent's afterAll cannot leak them.
    try {
      Utils.deleteRecursively(rootDir0)
      Utils.deleteRecursively(rootDir1)
    } finally {
      super.afterAll()
    }
  }

  override def beforeEach(): Unit = {
    super.beforeEach()
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach(): Unit = {
    try {
      diskBlockManager.stop()
    } finally {
      super.afterEach()
    }
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  /**
   * Appends `numBytes` characters (char codes 0 until `numBytes`) to `file`.
   *
   * @param file     file to append to; opened in append mode
   * @param numBytes number of characters to write
   */
  def writeToFile(file: File, numBytes: Int): Unit = {
    val writer = new FileWriter(file, true)
    // Close the writer even if a write fails, so the file handle is not leaked.
    try {
      for (i <- 0 until numBytes) writer.write(i)
    } finally {
      writer.close()
    }
  }
}
Example 15
Source File: HeartbeatReceiverSuite.scala From iolap with Apache License 2.0 | 5 votes |
package org.apache.spark

import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
import org.mockito.Mockito.{mock, spy, verify, when}
import org.mockito.Matchers
import org.mockito.Matchers._

import org.apache.spark.scheduler.TaskScheduler
import org.apache.spark.util.RpcUtils
import org.scalatest.concurrent.Eventually._

/**
 * Sends a heartbeat to a [[HeartbeatReceiver]] backed by a mocked [[TaskScheduler]]
 * and checks the `reregisterBlockManager` flag of the response.
 */
class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext {

  test("HeartbeatReceiver") {
    val response = heartbeatResponseWhen(schedulerReply = true)
    assert(false === response.reregisterBlockManager)
  }

  test("HeartbeatReceiver re-register") {
    val response = heartbeatResponseWhen(schedulerReply = false)
    assert(true === response.reregisterBlockManager)
  }

  /**
   * Registers a heartbeat receiver on a spied local SparkContext, sends one heartbeat
   * for "executor-1" and verifies that it was forwarded to the mocked scheduler.
   *
   * @param schedulerReply value the mocked `executorHeartbeatReceived` returns
   * @return the response the receiver sent back for the heartbeat
   */
  private def heartbeatResponseWhen(schedulerReply: Boolean): HeartbeatResponse = {
    sc = spy(new SparkContext("local[2]", "test"))
    val scheduler = mock(classOf[TaskScheduler])
    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(schedulerReply)
    when(sc.taskScheduler).thenReturn(scheduler)

    val heartbeatReceiver = new HeartbeatReceiver(sc)
    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
    // Wait until the receiver has picked up the scheduler before sending heartbeats.
    eventually(timeout(5 seconds), interval(5 millis)) {
      assert(heartbeatReceiver.scheduler != null)
    }
    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)

    val metrics = new TaskMetrics
    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
    val response = receiverRef.askWithRetry[HeartbeatResponse](
      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))

    verify(scheduler).executorHeartbeatReceived(
      Matchers.eq("executor-1"),
      Matchers.eq(Array(1L -> metrics)),
      Matchers.eq(blockManagerId))
    response
  }
}
Example 16
Source File: LogPageSuite.scala From iolap with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

/**
 * Calls the private `getLog` method of [[LogPage]] with log files placed inside and
 * outside the worker's working directory.
 */
class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")

    try {
      write(tmpOut, out)
      write(tmpErr, err)
      write(tmpOutBad, out)
      write(tmpErrBad, err)
      write(tmpRand, "1 6 4 5 2 7 8")

      // Get the logs. All log types other than "stderr" or "stdout" will be rejected
      val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
      val (stdout, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
      val (stderr, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
      val (error1, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
      val (error2, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
      // These files exist, but live outside the working directory
      val (error3, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
      val (error4, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)

      assert(stdout === out)
      assert(stderr === err)
      assert(error1.startsWith("Error: Log type must be one of "))
      assert(error2.startsWith("Error: Log type must be one of "))
      assert(error3.startsWith("Error: invalid log directory"))
      assert(error4.startsWith("Error: invalid log directory"))
    } finally {
      // Do not leave fixtures behind in the shared temp directory. The work dir is
      // deleted last; File.delete() leaves it in place if it still holds other files.
      Seq(tmpOut, tmpErr, tmpRand, tmpOutBad, tmpErrBad, workDir).foreach(_.delete())
    }
  }

  /** Writes `s` to `f`, replacing any existing content, and always closes the writer. */
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }
}
Example 17
Source File: NettyBlockTransferServiceSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
import org.apache.spark.network.BlockDataManager

/**
 * Checks which port a [[NettyBlockTransferService]] reports after being created with
 * a random (0) or an explicitly requested port.
 */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /** Closes any service a test created, in creation order, then delegates to super. */
  override def afterEach(): Unit = {
    try {
      Option(service0).foreach { svc =>
        svc.close()
        service0 = null
      }
      Option(service1).foreach { svc =>
        svc.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val requested = 17634 + Random.nextInt(10000)
    logInfo(s"random port for test: $requested")
    service0 = createService(requested)
    verifyServicePort(expectedPort = requested, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val requested = 17634 + Random.nextInt(10000)
    logInfo(s"random port for test: $requested")
    service0 = createService(requested)
    verifyServicePort(expectedPort = requested, actualPort = service0.port)

    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    service1 = createService(service0.port)
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  /** Asserts that `actualPort` lies within `[expectedPort, expectedPort + 10]`. */
  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    // avoid testing equality in case of simultaneous tests
    actualPort should be >= expectedPort
    actualPort should be <= (expectedPort + 10)
  }

  /**
   * Builds a transfer service for "localhost" on `port` and initialises it with a
   * mocked [[BlockDataManager]].
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val sparkConf = new SparkConf().set("spark.app.id", s"test-${getClass.getName}")
    val transferService = new NettyBlockTransferService(
      sparkConf, new SecurityManager(sparkConf), "localhost", "localhost", port, 1)
    transferService.init(mock(classOf[BlockDataManager]))
    transferService
  }
}
Example 18
Source File: StagePageSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{RETURNS_SMART_NULLS, mock, when}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener
import org.apache.spark.util.Utils

/**
 * Renders a [[StagePage]] against a mocked [[StagesTab]] and inspects the
 * "peak execution memory" output in the generated HTML.
 */
class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory should displayed") {
    val rendered = renderStagePage(new SparkConf(false)).toString().toLowerCase
    assert(rendered.contains("peak execution memory"))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val rendered = renderStagePage(new SparkConf(false)).toString().toLowerCase
    // min/25/50/75/max must each show the per-task value, not a cumulative one
    val perTaskCell = s"<td>$peakExecutionMemory.0 b</td>"
    assert(rendered.contains(perTaskCell * 5))
  }

  /**
   * Renders the page for stage 0 / attempt 0 after two finished tasks, each reporting
   * `peakExecutionMemory`, have been fed through the job progress listener.
   */
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val progress = new JobProgressListener(conf, Utils.getCurrentUserName())
    val graph = new RDDOperationGraphListener(conf)
    val executors = new ExecutorsListener(new StorageStatusListener(conf), conf)

    val stagesTab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    when(stagesTab.conf).thenReturn(conf)
    when(stagesTab.progressListener).thenReturn(progress)
    when(stagesTab.operationGraphListener).thenReturn(graph)
    when(stagesTab.executorsListener).thenReturn(executors)
    when(stagesTab.appName).thenReturn("testing")
    when(stagesTab.headerTabs).thenReturn(Seq.empty)

    val httpRequest = mock(classOf[HttpServletRequest])
    when(httpRequest.getParameter("id")).thenReturn("0")
    when(httpRequest.getParameter("attempt")).thenReturn("0")

    val stagePage = new StagePage(stagesTab)

    // Simulate a stage in job progress listener
    val stage = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    for (taskId <- 1 to 2) {
      val task = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
      progress.onStageSubmitted(SparkListenerStageSubmitted(stage))
      progress.onTaskStart(SparkListenerTaskStart(0, 0, task))
      task.markFinished(TaskState.FINISHED)
      val metrics = TaskMetrics.empty
      metrics.incPeakExecutionMemory(peakExecutionMemory)
      progress.onTaskEnd(SparkListenerTaskEnd(0, 0, "result", Success, task, metrics))
    }
    progress.onStageCompleted(SparkListenerStageCompleted(stage))
    stagePage.render(httpRequest)
  }
}
Example 19
Source File: LogPageSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

/**
 * Calls the private `getLog` method of [[LogPage]] with log files placed inside and
 * outside the worker's working directory.
 */
class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")

    try {
      write(tmpOut, out)
      write(tmpErr, err)
      write(tmpOutBad, out)
      write(tmpErrBad, err)
      write(tmpRand, "1 6 4 5 2 7 8")

      // Get the logs. All log types other than "stderr" or "stdout" will be rejected
      val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
      val (stdout, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
      val (stderr, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
      val (error1, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
      val (error2, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
      // These files exist, but live outside the working directory
      val (error3, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
      val (error4, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)

      assert(stdout === out)
      assert(stderr === err)
      assert(error1.startsWith("Error: Log type must be one of "))
      assert(error2.startsWith("Error: Log type must be one of "))
      assert(error3.startsWith("Error: invalid log directory"))
      assert(error4.startsWith("Error: invalid log directory"))
    } finally {
      // Do not leave fixtures behind in the shared temp directory. The work dir is
      // deleted last; File.delete() leaves it in place if it still holds other files.
      Seq(tmpOut, tmpErr, tmpRand, tmpOutBad, tmpErrBad, workDir).foreach(_.delete())
    }
  }

  /** Writes `s` to `f`, replacing any existing content, and always closes the writer. */
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }
}
Example 20
Source File: MasterWebUISuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}

/**
 * Binds a [[MasterWebUI]] around a mocked [[Master]] and checks that the kill
 * endpoints forward the request to the master.
 */
class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll(): Unit = {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll(): Unit = {
    // Delegate to super even if stopping the UI throws.
    try {
      masterWebUI.stop()
    } finally {
      super.afterAll()
    }
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Reading the response code is what actually performs the request.
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Reading the response code is what actually performs the request.
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  /** Joins the entries as `name=value` pairs separated by `&`; values are not URL-encoded. */
  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  /**
   * Opens an HTTP connection to `url` and, when `body` is non-empty, writes it as a
   * form-encoded request body.
   *
   * @param url    target URL
   * @param method HTTP method, e.g. "POST"
   * @param body   request body; nothing is written when it is empty
   * @return the connection; the caller triggers the request by reading the response
   */
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      // Content-Length counts bytes, so measure the encoded payload, not the string.
      val payload = body.getBytes(StandardCharsets.UTF_8)
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(payload.length))
      val out = new DataOutputStream(conn.getOutputStream)
      try {
        out.write(payload)
      } finally {
        out.close()
      }
    }
    conn
  }
}
Example 21
Source File: NettyBlockTransferServiceSuite.scala From multi-tenancy-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
import org.apache.spark.network.BlockDataManager

/**
 * Checks which port a [[NettyBlockTransferService]] reports after being created with
 * a random (0) or an explicitly requested port.
 */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /** Closes any service a test created, in creation order, then delegates to super. */
  override def afterEach(): Unit = {
    try {
      Option(service0).foreach { svc =>
        svc.close()
        service0 = null
      }
      Option(service1).foreach { svc =>
        svc.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val requested = 17634 + Random.nextInt(10000)
    logInfo(s"random port for test: $requested")
    service0 = createService(requested)
    verifyServicePort(expectedPort = requested, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val requested = 17634 + Random.nextInt(10000)
    logInfo(s"random port for test: $requested")
    service0 = createService(requested)
    verifyServicePort(expectedPort = requested, actualPort = service0.port)

    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    service1 = createService(service0.port)
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  /** Asserts that `actualPort` lies within `[expectedPort, expectedPort + 10]`. */
  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    // avoid testing equality in case of simultaneous tests
    actualPort should be >= expectedPort
    actualPort should be <= (expectedPort + 10)
  }

  /**
   * Builds a transfer service for "localhost" on `port` and initialises it with a
   * mocked [[BlockDataManager]].
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val sparkConf = new SparkConf().set("spark.app.id", s"test-${getClass.getName}")
    val transferService = new NettyBlockTransferService(
      sparkConf, new SecurityManager(sparkConf), "localhost", "localhost", port, 1)
    transferService.init(mock(classOf[BlockDataManager]))
    transferService
  }
}
Example 22
Source File: DiskBlockManagerSuite.scala From SparkCore with Apache License 2.0 | 5 votes |
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

/**
 * Exercises [[DiskBlockManager]] against two temporary root directories that are
 * created once per suite and removed when the suite finishes.
 */
class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll(): Unit = {
    // Remove the temp dirs first so a failure in a parent's afterAll cannot leak them.
    try {
      Utils.deleteRecursively(rootDir0)
      Utils.deleteRecursively(rootDir1)
    } finally {
      super.afterAll()
    }
  }

  override def beforeEach(): Unit = {
    super.beforeEach()
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach(): Unit = {
    try {
      diskBlockManager.stop()
    } finally {
      super.afterEach()
    }
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  /**
   * Appends `numBytes` characters (char codes 0 until `numBytes`) to `file`.
   *
   * @param file     file to append to; opened in append mode
   * @param numBytes number of characters to write
   */
  def writeToFile(file: File, numBytes: Int): Unit = {
    val writer = new FileWriter(file, true)
    // Close the writer even if a write fails, so the file handle is not leaked.
    try {
      for (i <- 0 until numBytes) writer.write(i)
    } finally {
      writer.close()
    }
  }
}
Example 23
Source File: StagePageSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

/**
 * Renders a [[StagePage]] against a mocked [[StagesTab]] and inspects the
 * "peak execution memory" output in the generated HTML.
 */
class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory should displayed") {
    val rendered = renderStagePage(new SparkConf(false)).toString().toLowerCase
    assert(rendered.contains("peak execution memory"))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val rendered = renderStagePage(new SparkConf(false)).toString().toLowerCase
    // min/25/50/75/max must each show the per-task value, not a cumulative one
    val perTaskCell = s"<td>$peakExecutionMemory.0 b</td>"
    assert(rendered.contains(perTaskCell * 5))
  }

  /**
   * Renders the page for stage 0 / attempt 0 after two finished tasks, each reporting
   * `peakExecutionMemory`, have been fed through the job progress listener.
   */
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val progress = new JobProgressListener(conf)
    val graph = new RDDOperationGraphListener(conf)
    val executors = new ExecutorsListener(new StorageStatusListener(conf), conf)

    val stagesTab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    when(stagesTab.conf).thenReturn(conf)
    when(stagesTab.progressListener).thenReturn(progress)
    when(stagesTab.operationGraphListener).thenReturn(graph)
    when(stagesTab.executorsListener).thenReturn(executors)
    when(stagesTab.appName).thenReturn("testing")
    when(stagesTab.headerTabs).thenReturn(Seq.empty)

    val httpRequest = mock(classOf[HttpServletRequest])
    when(httpRequest.getParameter("id")).thenReturn("0")
    when(httpRequest.getParameter("attempt")).thenReturn("0")

    val stagePage = new StagePage(stagesTab)

    // Simulate a stage in job progress listener
    val stage = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    for (taskId <- 1 to 2) {
      val task = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
      progress.onStageSubmitted(SparkListenerStageSubmitted(stage))
      progress.onTaskStart(SparkListenerTaskStart(0, 0, task))
      task.markFinished(TaskState.FINISHED)
      val metrics = TaskMetrics.empty
      metrics.incPeakExecutionMemory(peakExecutionMemory)
      progress.onTaskEnd(SparkListenerTaskEnd(0, 0, "result", Success, task, metrics))
    }
    progress.onStageCompleted(SparkListenerStageCompleted(stage))
    stagePage.render(httpRequest)
  }
}
Example 24
Source File: LogPageSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

/**
 * Calls the private `getLog` method of [[LogPage]] with log files placed inside and
 * outside the worker's working directory.
 */
class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")

    try {
      write(tmpOut, out)
      write(tmpErr, err)
      write(tmpOutBad, out)
      write(tmpErrBad, err)
      write(tmpRand, "1 6 4 5 2 7 8")

      // Get the logs. All log types other than "stderr" or "stdout" will be rejected
      val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
      val (stdout, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
      val (stderr, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
      val (error1, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
      val (error2, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
      // These files exist, but live outside the working directory
      val (error3, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
      val (error4, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)

      assert(stdout === out)
      assert(stderr === err)
      assert(error1.startsWith("Error: Log type must be one of "))
      assert(error2.startsWith("Error: Log type must be one of "))
      assert(error3.startsWith("Error: invalid log directory"))
      assert(error4.startsWith("Error: invalid log directory"))
    } finally {
      // Do not leave fixtures behind in the shared temp directory. The work dir is
      // deleted last; File.delete() leaves it in place if it still holds other files.
      Seq(tmpOut, tmpErr, tmpRand, tmpOutBad, tmpErrBad, workDir).foreach(_.delete())
    }
  }

  /** Writes `s` to `f`, replacing any existing content, and always closes the writer. */
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }
}
Example 25
Source File: MasterWebUISuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}

/**
 * Binds a [[MasterWebUI]] around a mocked [[Master]] and checks that the kill
 * endpoints forward the request to the master.
 */
class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll(): Unit = {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll(): Unit = {
    // Delegate to super even if stopping the UI throws.
    try {
      masterWebUI.stop()
    } finally {
      super.afterAll()
    }
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Reading the response code is what actually performs the request.
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Reading the response code is what actually performs the request.
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  /** Joins the entries as `name=value` pairs separated by `&`; values are not URL-encoded. */
  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  /**
   * Opens an HTTP connection to `url` and, when `body` is non-empty, writes it as a
   * form-encoded request body.
   *
   * @param url    target URL
   * @param method HTTP method, e.g. "POST"
   * @param body   request body; nothing is written when it is empty
   * @return the connection; the caller triggers the request by reading the response
   */
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      // Content-Length counts bytes, so measure the encoded payload, not the string.
      val payload = body.getBytes(StandardCharsets.UTF_8)
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(payload.length))
      val out = new DataOutputStream(conn.getOutputStream)
      try {
        out.write(payload)
      } finally {
        out.close()
      }
    }
    conn
  }
}
Example 26
Source File: NettyBlockTransferServiceSuite.scala From sparkoscope with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
import org.apache.spark.network.BlockDataManager

/**
 * Checks which port a [[NettyBlockTransferService]] reports after being initialised with a
 * random (0) or an explicit port request.
 */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  // Services created by the current test; closed and reset in afterEach.
  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /** Closes whichever services the test created, then runs the parent teardown. */
  override def afterEach(): Unit = {
    try {
      Option(service0).foreach { svc =>
        svc.close()
        service0 = null
      }
      Option(service1).foreach { svc =>
        svc.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val requestedPort = 17634 + Random.nextInt(10000)
    logInfo(s"random port for test: $requestedPort")
    service0 = createService(requestedPort)
    verifyServicePort(expectedPort = requestedPort, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val requestedPort = 17634 + Random.nextInt(10000)
    logInfo(s"random port for test: $requestedPort")
    service0 = createService(requestedPort)
    verifyServicePort(expectedPort = requestedPort, actualPort = service0.port)
    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    service1 = createService(service0.port)
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  /**
   * Asserts that `actualPort` lies in the inclusive range `[expectedPort, expectedPort + 10]`.
   * A range is used instead of equality in case of simultaneous tests.
   */
  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    val upperBound = expectedPort + 10
    actualPort should be >= expectedPort
    actualPort should be <= upperBound
  }

  /**
   * Builds a service for "localhost" with the requested port and initialises it with a mocked
   * [[BlockDataManager]].
   *
   * @param port port passed to the service constructor (the tests use 0 for "random")
   * @return the initialised service; the caller stores it so afterEach can close it
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val sparkConf = new SparkConf().set("spark.app.id", s"test-${getClass.getName}")
    val secMgr = new SecurityManager(sparkConf)
    val dataManager = mock(classOf[BlockDataManager])
    val transferService =
      new NettyBlockTransferService(sparkConf, secMgr, "localhost", "localhost", port, 1)
    transferService.init(dataManager)
    transferService
  }
}
Example 27
Source File: StagePageSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

/**
 * Renders a [[StagePage]] against a mocked [[StagesTab]] and inspects the produced HTML for
 * the "peak execution memory" metric.
 */
class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  // Peak execution memory recorded for every simulated task.
  private val peakExecutionMemory = 10

  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeKey = "spark.sql.unsafe.enabled"
    val expectedText = "peak execution memory"

    // Explicitly enabled: the metric is rendered
    val enabledConf = new SparkConf(false).set(unsafeKey, "true")
    val enabledHtml = renderStagePage(enabledConf).toString().toLowerCase
    assert(enabledHtml.contains(expectedText))

    // Disable unsafe and make sure it's not there
    val disabledConf = new SparkConf(false).set(unsafeKey, "false")
    val disabledHtml = renderStagePage(disabledConf).toString().toLowerCase
    assert(!disabledHtml.contains(expectedText))

    // Avoid setting anything; it should be displayed by default
    val defaultConf = new SparkConf(false)
    val defaultHtml = renderStagePage(defaultConf).toString().toLowerCase
    assert(defaultHtml.contains(expectedText))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val enabledConf = new SparkConf(false).set("spark.sql.unsafe.enabled", "true")
    val pageHtml = renderStagePage(enabledConf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    val perTaskCell = s"<td>$peakExecutionMemory.0 b</td>"
    assert(pageHtml.contains(perTaskCell * 5))
  }

  /**
   * Feeds one stage with two finished tasks into a fresh [[JobProgressListener]] and renders
   * the stage page for stage id "0", attempt "0".
   *
   * @param conf configuration handed to the listeners and exposed through the mocked tab
   * @return the nodes returned by `StagePage.render`
   */
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val progressListener = new JobProgressListener(conf)
    val operationGraphListener = new RDDOperationGraphListener(conf)
    val execListener = new ExecutorsListener(new StorageStatusListener(conf), conf)

    // Mocked tab that exposes the real listeners built above.
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(progressListener)
    when(tab.operationGraphListener).thenReturn(operationGraphListener)
    when(tab.executorsListener).thenReturn(execListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)

    // The request carries the stage id and attempt as query parameters.
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")

    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    for (taskId <- 1 to 2) {
      val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
      // The stage-submitted event is posted on every iteration, before the task starts.
      progressListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
      progressListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
      taskInfo.markFinished(TaskState.FINISHED)

      val metrics = TaskMetrics.empty
      metrics.incPeakExecutionMemory(peakExecutionMemory)
      val taskEnd = SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, metrics)
      progressListener.onTaskEnd(taskEnd)
    }
    progressListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))

    page.render(request)
  }
}
Example 28
Source File: LogPageSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

/**
 * Calls the private `getLog` method of [[LogPage]] against fixture files on disk and checks
 * both the returned log text and the error messages for rejected requests.
 */
class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")

    // Fixture files: three inside the working directory, two directly in the temp directory.
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")

    // Remember whether this test created the directory so cleanup never removes a
    // directory that existed beforehand.
    val createdWorkDir = workDir.mkdir()
    try {
      assert(createdWorkDir || workDir.isDirectory,
        s"could not create working directory $workDir")

      when(webui.workDir).thenReturn(workDir)
      when(webui.worker).thenReturn(worker)
      when(worker.conf).thenReturn(new SparkConf())
      val logPage = new LogPage(webui)

      // Prepare some fake log files to read later
      val out = "some stdout here"
      val err = "some stderr here"
      write(tmpOut, out)
      write(tmpErr, err)
      write(tmpOutBad, out)
      write(tmpErrBad, err)
      write(tmpRand, "1 6 4 5 2 7 8")

      // Get the logs. All log types other than "stderr" or "stdout" will be rejected
      // Symbol("...") is used instead of the deprecated symbol literal syntax.
      val getLog = PrivateMethod[(String, Long, Long, Long)](Symbol("getLog"))
      val (stdout, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
      val (stderr, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
      val (error1, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
      val (error2, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
      // These files exist, but live outside the working directory
      val (error3, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
      val (error4, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)

      assert(stdout === out)
      assert(stderr === err)
      assert(error1.startsWith("Error: Log type must be one of "))
      assert(error2.startsWith("Error: Log type must be one of "))
      assert(error3.startsWith("Error: invalid log directory"))
      assert(error4.startsWith("Error: invalid log directory"))
    } finally {
      // Best-effort cleanup so repeated runs do not leave fixtures in the shared temp dir.
      Seq(tmpOut, tmpErr, tmpRand, tmpOutBad, tmpErrBad).foreach(_.delete())
      if (createdWorkDir) {
        workDir.delete()
      }
    }
  }

  /**
   * Writes `s` to file `f`, replacing any existing content; the writer is always closed.
   *
   * @param f destination file
   * @param s text to write
   */
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }
}
Example 29
Source File: MasterWebUISuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}

/**
 * Posts to the `/app/kill/` and `/driver/kill/` endpoints of a bound [[MasterWebUI]] and uses
 * Mockito verification to check what the mocked [[Master]] and its endpoint ref were asked to do.
 */
class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])

  // The mocked master returns the collaborators created above from its accessors.
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)

  // NOTE(review): 0 is presumably the requested port; the port actually used is read from
  // `boundPort` in the tests -- confirm against MasterWebUI.
  val masterWebUI = new MasterWebUI(master, 0)

  /** Runs the parent setup, then binds the web UI. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    masterWebUI.bind()
  }

  /** Stops the web UI and always runs the parent teardown, even when stopping fails. */
  override def afterAll(): Unit = {
    try {
      masterWebUI.stop()
    } finally {
      super.afterAll()
    }
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Asking for the response code makes the connection perform the request.
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    // Asking for the response code makes the connection perform the request.
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  /**
   * Serialises form fields into an `application/x-www-form-urlencoded` request body.
   *
   * Each name and value is URL-encoded as UTF-8 before joining, so characters such as `&`,
   * `=` and spaces are transmitted safely. Simple identifiers pass through unchanged.
   *
   * @param data field name to field value
   * @return `name=value` pairs separated by `&`
   */
  private def convPostDataToString(data: Map[String, String]): String = {
    import java.net.URLEncoder
    val charset = StandardCharsets.UTF_8.name()
    data.map { case (name, value) =>
      s"${URLEncoder.encode(name, charset)}=${URLEncoder.encode(value, charset)}"
    }.mkString("&")
  }

  /**
   * Creates an HTTP connection for `url` and writes `body` as a form post when it is non-empty.
   *
   * @param url    absolute URL to connect to
   * @param method HTTP method name, e.g. "POST"
   * @param body   request body; skipped entirely when empty
   * @return the prepared connection for the caller to query (e.g. `getResponseCode`)
   */
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      // Content-Length must describe the encoded bytes, which can outnumber the string's
      // chars for non-ASCII content, so encode first and measure the byte array.
      val payload = body.getBytes(StandardCharsets.UTF_8)
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(payload.length))
      val out = new DataOutputStream(conn.getOutputStream)
      try {
        out.write(payload)
      } finally {
        out.close()
      }
    }
    conn
  }
}