org.mockito.Mockito.mock Scala Examples

The following examples show how to use org.mockito.Mockito.mock. Each one is taken from an open-source Scala project; the originating project, source file, and license are noted above the code.
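Before the examples, here is a minimal orientation sketch of the pattern they all share: mock(classOf[T]) creates a mock instance of T, and when(...).thenReturn(...) stubs individual calls on it. The Greeter trait below is hypothetical and exists only for illustration; the sketch assumes mockito-core is on the test classpath.

import org.mockito.Mockito.{mock, when}

object MinimalMockSketch {
  // Hypothetical trait, used only to show the mock/when/thenReturn pattern.
  trait Greeter {
    def greet(name: String): String
  }

  def main(args: Array[String]): Unit = {
    // Create a mock instance; unstubbed methods return Mockito defaults (null, 0, false).
    val greeter = mock(classOf[Greeter])
    // Stub a specific call so the mock returns a canned value.
    when(greeter.greet("world")).thenReturn("hello, world")
    assert(greeter.greet("world") == "hello, world")
  }
}
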
Example 1
Source File: NaptimeModuleTest.scala    From naptime   with Apache License 2.0
package org.coursera.naptime

import java.util.Date
import javax.inject.Inject

import akka.stream.Materializer
import com.google.inject.Guice
import com.google.inject.Stage
import com.linkedin.data.schema.DataSchema
import com.linkedin.data.schema.DataSchemaUtil
import com.linkedin.data.schema.PrimitiveDataSchema
import com.linkedin.data.schema.RecordDataSchema
import org.coursera.naptime.model.KeyFormat
import org.coursera.naptime.resources.TopLevelCollectionResource
import org.coursera.naptime.router2.NaptimeRoutes
import org.junit.Test
import org.mockito.Mockito.mock
import org.scalatest.junit.AssertionsForJUnit
import play.api.libs.json.Json
import play.api.libs.json.OFormat

import scala.concurrent.ExecutionContext

object NaptimeModuleTest {
  case class User(name: String, createdAt: Date)
  object User {
    implicit val oFormat: OFormat[User] = Json.format[User]
  }
  class MyResource(implicit val executionContext: ExecutionContext, val materializer: Materializer)
      extends TopLevelCollectionResource[String, User] {
    override implicit def resourceFormat: OFormat[User] = User.oFormat
    override def keyFormat: KeyFormat[KeyType] = KeyFormat.stringKeyFormat
    override def resourceName: String = "myResource"
    implicit val fields = Fields

    def get(id: String) = Nap.get(ctx => ???)
  }
  object MyFakeModule extends NaptimeModule {
    override def configure(): Unit = {
      bindResource[MyResource]
      bind[MyResource].toInstance(mock(classOf[MyResource]))
      bindSchemaType[Date](DataSchemaUtil.dataSchemaTypeToPrimitiveDataSchema(DataSchema.Type.LONG))
    }
  }

  class OverrideTypesHelper @Inject()(val schemaOverrideTypes: NaptimeModule.SchemaTypeOverrides)
}

class NaptimeModuleTest extends AssertionsForJUnit {
  import NaptimeModuleTest._

  
  @Test
  def checkInferredOverrides(): Unit = {
    val injector = Guice.createInjector(Stage.DEVELOPMENT, MyFakeModule, NaptimeModule)
    val overrides = injector.getInstance(classOf[OverrideTypesHelper])
    assert(overrides.schemaOverrideTypes.size === 1)
    assert(overrides.schemaOverrideTypes.contains("java.util.Date"))
  }

  @Test
  def checkComputedOverrides(): Unit = {
    val injector = Guice.createInjector(Stage.DEVELOPMENT, MyFakeModule, NaptimeModule)
    val overrides = injector.getInstance(classOf[OverrideTypesHelper])
    val routes = injector.getInstance(classOf[NaptimeRoutes])
    assert(1 === routes.routerBuilders.size)
    val routerBuilder = routes.routerBuilders.head
    val inferredSchemaKeyed =
      routerBuilder.types.find(_.key == "org.coursera.naptime.NaptimeModuleTest.User").get
    assert(inferredSchemaKeyed.value.isInstanceOf[RecordDataSchema])
    val userSchema = inferredSchemaKeyed.value.asInstanceOf[RecordDataSchema]
    assert(2 === userSchema.getFields.size())
    val initialCreatedAtSchema = userSchema.getField("createdAt").getType.getDereferencedDataSchema
    assert(initialCreatedAtSchema.isInstanceOf[RecordDataSchema])
    assert(
      initialCreatedAtSchema
        .asInstanceOf[RecordDataSchema]
        .getDoc
        .contains("Unable to infer schema"))
    SchemaUtils.fixupInferredSchemas(userSchema, overrides.schemaOverrideTypes)
    val fixedCreatedAtSchema = userSchema.getField("createdAt").getType.getDereferencedDataSchema
    assert(fixedCreatedAtSchema.isInstanceOf[PrimitiveDataSchema])
  }
} 
Example 2
Source File: NettyBlockTransferServiceSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  override def afterEach() {
    if (service0 != null) {
      service0.close()
      service0 = null
    }

    if (service1 != null) {
      service1.close()
      service1 = null
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634
    service0 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10) // avoid testing equality in case of simultaneous tests
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
} 
Example 3
Source File: DiskBlockManagerSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
Example 4
Source File: StagePageSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
    // Disable unsafe and make sure it's not there
    val conf2 = new SparkConf(false).set(unsafeConf, "false")
    val html2 = renderStagePage(conf2).toString().toLowerCase
    assert(!html2.contains(targetString))
    // Avoid setting anything; it should be displayed by default
    val conf3 = new SparkConf(false)
    val html3 = renderStagePage(conf3).toString().toLowerCase
    assert(html3.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains("<td>10.0 b</td>" * 5))
  }

  
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        val peakExecutionMemory = 10
        taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY,
          Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markSuccessful()
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
Example 5
Source File: LogPageSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 6
Source File: MasterWebUISuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.util.Date

import scala.io.Source
import scala.language.postfixOps

import org.json4s.jackson.JsonMethods._
import org.json4s.JsonAST.{JNothing, JString, JInt}
import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.MasterStateResponse
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.RpcEnv


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter {

  val masterPage = mock(classOf[MasterPage])
  val master = {
    val conf = new SparkConf
    val securityMgr = new SecurityManager(conf)
    val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr)
    val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf)
    master
  }
  val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage))

  before {
    masterWebUI.bind()
  }

  after {
    masterWebUI.stop()
  }

  test("list applications") {
    val worker = createWorkerInfo()
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue)
    activeApp.addExecutor(worker, 2)

    val workers = Array[WorkerInfo](worker)
    val activeApps = Array(activeApp)
    val completedApps = Array[ApplicationInfo]()
    val activeDrivers = Array[DriverInfo]()
    val completedDrivers = Array[DriverInfo]()
    val stateResponse = new MasterStateResponse(
      "host", 8080, None, workers, activeApps, completedApps,
      activeDrivers, completedDrivers, RecoveryState.ALIVE)

    when(masterPage.getMasterState).thenReturn(stateResponse)

    val resultJson = Source.fromURL(
      s"http://localhost:${masterWebUI.boundPort}/api/v1/applications")
      .mkString
    val parsedJson = parse(resultJson)
    val firstApp = parsedJson(0)

    assert(firstApp \ "id" === JString(activeApp.id))
    assert(firstApp \ "name" === JString(activeApp.desc.name))
    assert(firstApp \ "coresGranted" === JInt(2))
    assert(firstApp \ "maxCores" === JInt(4))
    assert(firstApp \ "memoryPerExecutorMB" === JInt(1234))
    assert(firstApp \ "coresPerExecutor" === JNothing)
  }

} 
Example 7
Source File: NettyBlockTransferServiceSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  override def afterEach() {
    if (service0 != null) {
      service0.close()
      service0 = null
    }

    if (service1 != null) {
      service1.close()
      service1 = null
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634
    service0 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10) // avoid testing equality in case of simultaneous tests
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
} 
Example 8
Source File: LogPageSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 9
Source File: MasterWebUISuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 10
Source File: NettyBlockTransferServiceSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.network.BlockDataManager

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with Matchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  override def afterEach() {
    try {
      if (service0 != null) {
        service0.close()
        service0 = null
      }

      if (service1 != null) {
        service1.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
    service1 = createService(service0.port)
    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    actualPort should be >= expectedPort
    // avoid testing equality in case of simultaneous tests
    // the default value for `spark.port.maxRetries` is 100 under test
    actualPort should be <= (expectedPort + 100)
  }

  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost",
      port, 1)
    service.init(blockDataManager)
    service
  }
} 
Example 11
Source File: DiskBlockManagerSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils
// DiskBlockManager manages and maintains the mapping between logical blocks and the physical blocks stored on disk.
// In general, a logical block is mapped to a physical file whose name is derived from its BlockId.
class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  // DiskBlockManager creates and maintains the logical mapping between logical blocks and physical disk locations; by default, a block is mapped to one file whose name is given by its BlockId.
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {//基本块的创建
    val blockId = new TestBlockId("test")
    // By default, a block is mapped to a file whose name is given by its BlockId.
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {//枚举块
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
Example 12
Source File: StagePageSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {
  // The peak execution memory value is only displayed when unsafe is enabled
  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    println("===="+html)
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
    // Disable unsafe and make sure it's not there
    val conf2 = new SparkConf(false).set(unsafeConf, "false")
    val html2 = renderStagePage(conf2).toString().toLowerCase
    assert(!html2.contains(targetString))
    // Avoid setting anything; it should be displayed by default
    val conf3 = new SparkConf(false)
    val html3 = renderStagePage(conf3).toString().toLowerCase
    assert(html3.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains("<td>10.0 b</td>" * 5))
  }

  
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        val peakExecutionMemory = 10
        taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY,
          Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markSuccessful()
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
Example 13
Source File: NettyBlockTransferServiceSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  override def afterEach() {
    if (service0 != null) {
      service0.close()
      service0 = null
    }

    if (service1 != null) {
      service1.close()
      service1 = null
    }
  }

  test("can bind to a random port") {//可以绑定到一个随机端口
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {//可以绑定到两个随机端口
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {//可以绑定到一个特定的端口
    val port = 17634
    service0 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10) // avoid testing equality in case of simultaneous tests
  }
  // Can bind to a specific port twice and the second one increments
  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
} 
Example 14
Source File: DiskBlockManagerSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
Example 15
Source File: HeartbeatReceiverSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark

import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
import org.mockito.Mockito.{mock, spy, verify, when}
import org.mockito.Matchers
import org.mockito.Matchers._

import org.apache.spark.scheduler.TaskScheduler
import org.apache.spark.util.RpcUtils
import org.scalatest.concurrent.Eventually._

class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext {

  test("HeartbeatReceiver") {
    sc = spy(new SparkContext("local[2]", "test"))
    val scheduler = mock(classOf[TaskScheduler])
    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
    when(sc.taskScheduler).thenReturn(scheduler)

    val heartbeatReceiver = new HeartbeatReceiver(sc)
    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
    eventually(timeout(5 seconds), interval(5 millis)) {
      assert(heartbeatReceiver.scheduler != null)
    }
    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)

    val metrics = new TaskMetrics
    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
    val response = receiverRef.askWithRetry[HeartbeatResponse](
      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))

    verify(scheduler).executorHeartbeatReceived(
      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
    assert(false === response.reregisterBlockManager)
  }

  test("HeartbeatReceiver re-register") {
    sc = spy(new SparkContext("local[2]", "test"))
    val scheduler = mock(classOf[TaskScheduler])
    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
    when(sc.taskScheduler).thenReturn(scheduler)

    val heartbeatReceiver = new HeartbeatReceiver(sc)
    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
    eventually(timeout(5 seconds), interval(5 millis)) {
      assert(heartbeatReceiver.scheduler != null)
    }
    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)

    val metrics = new TaskMetrics
    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
    val response = receiverRef.askWithRetry[HeartbeatResponse](
      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))

    verify(scheduler).executorHeartbeatReceived(
      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
    assert(true === response.reregisterBlockManager)
  }
} 
Example 16
Source File: LogPageSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 17
Source File: NettyBlockTransferServiceSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
import org.apache.spark.network.BlockDataManager

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  override def afterEach() {
    try {
      if (service0 != null) {
        service0.close()
        service0 = null
      }

      if (service1 != null) {
        service1.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
    service1 = createService(service0.port)
    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    actualPort should be >= expectedPort
    // avoid testing equality in case of simultaneous tests
    actualPort should be <= (expectedPort + 10)
  }

  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost",
      port, 1)
    service.init(blockDataManager)
    service
  }
} 
Example 18
Source File: StagePageSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{RETURNS_SMART_NULLS, mock, when}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener
import org.apache.spark.util.Utils

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory should displayed") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
  }

  
  private def renderStagePage(conf: SparkConf): Seq[Node] = {

    val jobListener = new JobProgressListener(conf, Utils.getCurrentUserName())
    val graphListener = new RDDOperationGraphListener(conf)
    val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.executorsListener).thenReturn(executorsListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markFinished(TaskState.FINISHED)
        val taskMetrics = TaskMetrics.empty
        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
Example 19
Source File: LogPageSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 20
Source File: MasterWebUISuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 21
Source File: NettyBlockTransferServiceSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
import org.apache.spark.network.BlockDataManager

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  override def afterEach() {
    try {
      if (service0 != null) {
        service0.close()
        service0 = null
      }

      if (service1 != null) {
        service1.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
    service1 = createService(service0.port)
    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    actualPort should be >= expectedPort
    // avoid testing equality in case of simultaneous tests
    actualPort should be <= (expectedPort + 10)
  }

  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost",
      port, 1)
    service.init(blockDataManager)
    service
  }
} 
Example 22
Source File: DiskBlockManagerSuite.scala    From SparkCore   with Apache License 2.0
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
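Here the mock is created once as a field and a single accessor is stubbed, so every DiskBlockManager built in beforeEach sees the same configuration. A minimal sketch of that stubbing step, assuming the same Spark test classpath and version as the suite above:

package org.apache.spark.storage

import org.mockito.Mockito.{mock, when}

import org.apache.spark.SparkConf

object StubbedAccessorSketch {
  private val testConf = new SparkConf(false)
  // Stub only the accessor the code under test reads; every other method on the
  // mock keeps Mockito's default behaviour.
  val blockManager: BlockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
}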
Example 23
Source File: StagePageSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory should displayed") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
  }

  
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.executorsListener).thenReturn(executorsListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markFinished(TaskState.FINISHED)
        val taskMetrics = TaskMetrics.empty
        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
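RETURNS_SMART_NULLS is worth calling out: any StagesTab method the test forgets to stub returns a "smart null" that raises a descriptive SmartNullPointerException when it is later used, rather than a bare NPE deep inside the page rendering. A minimal sketch of the two mock flavours used above, assuming the Spark test classpath:

package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark.ui.jobs.StagesTab

object SmartNullsSketch {
  // Unstubbed calls on `tab` fail fast with a helpful message instead of returning null.
  val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
  // The request mock keeps plain defaults; only the parameters the page reads are stubbed.
  val request = mock(classOf[HttpServletRequest])
  when(request.getParameter("id")).thenReturn("0")
  when(request.getParameter("attempt")).thenReturn("0")
}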
Example 24
Source File: LogPageSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
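The test never starts a real worker or web UI; it wires a small graph of mocks so that LogPage can reach workDir, worker and worker.conf. A minimal sketch of that wiring, assuming the Spark test classpath (the package declaration gives access to the package-private classes):

package org.apache.spark.deploy.worker.ui

import java.io.File

import org.mockito.Mockito.{mock, when}

import org.apache.spark.SparkConf
import org.apache.spark.deploy.worker.Worker

object LogPageWiringSketch {
  val webui = mock(classOf[WorkerWebUI])
  val worker = mock(classOf[Worker])
  // One mock can return another, which is how the page reaches worker.conf.
  when(webui.workDir).thenReturn(new File(sys.props("java.io.tmpdir"), "work-dir"))
  when(webui.worker).thenReturn(worker)
  when(worker.conf).thenReturn(new SparkConf())
  val logPage = new LogPage(webui)
}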
Example 25
Source File: MasterWebUISuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
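Beyond stubbing, this suite uses verify to assert on interactions: after the POST is handled it checks that the master endpoint was asked to kill exactly one driver with the expected id. A minimal, self-contained sketch of that verification step (the interaction is simulated directly here instead of going through the UI), assuming the Spark test classpath:

package org.apache.spark.deploy.master.ui

import org.mockito.Mockito.{mock, times, verify}

import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.rpc.RpcEndpointRef

object VerifySketch extends App {
  val endpoint = mock(classOf[RpcEndpointRef])
  // Simulate the interaction the web UI performs when the kill form is posted...
  endpoint.ask[KillDriverResponse](RequestKillDriver("driver-0"))
  // ...then assert it happened exactly once, matching the message by case-class equality.
  verify(endpoint, times(1)).ask[KillDriverResponse](RequestKillDriver("driver-0"))
}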
Example 26
Source File: NettyBlockTransferServiceSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.network.netty

import scala.util.Random

import org.mockito.Mockito.mock
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
import org.apache.spark.network.BlockDataManager

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  override def afterEach() {
    try {
      if (service0 != null) {
        service0.close()
        service0 = null
      }

      if (service1 != null) {
        service1.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634 + Random.nextInt(10000)
    logInfo("random port for test: " + port)
    service0 = createService(port)
    verifyServicePort(expectedPort = port, actualPort = service0.port)
    service1 = createService(service0.port)
    // `service0.port` is occupied, so `service1.port` should not be `service0.port`
    verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
  }

  private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
    actualPort should be >= expectedPort
    // avoid testing equality in case of simultaneous tests
    actualPort should be <= (expectedPort + 10)
  }

  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost",
      port, 1)
    service.init(blockDataManager)
    service
  }
} 
Example 27
Source File: StagePageSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
    // Disable unsafe and make sure it's not there
    val conf2 = new SparkConf(false).set(unsafeConf, "false")
    val html2 = renderStagePage(conf2).toString().toLowerCase
    assert(!html2.contains(targetString))
    // Avoid setting anything; it should be displayed by default
    val conf3 = new SparkConf(false)
    val html3 = renderStagePage(conf3).toString().toLowerCase
    assert(html3.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
  }

  
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.executorsListener).thenReturn(executorsListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markFinished(TaskState.FINISHED)
        val taskMetrics = TaskMetrics.empty
        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
Example 28
Source File: LogPageSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 29
Source File: MasterWebUISuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
}