org.mockito.Mockito.when Scala Examples

The following examples show how to use org.mockito.Mockito.when. They are drawn from open-source projects; the source file, originating project, and license are listed above each example so you can trace each snippet back to its original code.
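Before the project-specific examples, here is a minimal, self-contained sketch of the pattern they all build on: create a mock, stub it with when(...).thenReturn(...), and optionally verify the interaction afterwards. The Greeter trait is hypothetical (not from any project below), and the sketch assumes only mockito-core on the classpath.

import org.mockito.Mockito.{mock, verify, when}

trait Greeter {
  def greet(name: String): String
}

object WhenSketch {
  def main(args: Array[String]): Unit = {
    // Create a mock of the trait and stub one specific call
    val greeter = mock(classOf[Greeter])
    when(greeter.greet("World")).thenReturn("Hello, World")

    assert(greeter.greet("World") == "Hello, World")
    // Calls that were never stubbed return a default value (null for AnyRef)
    assert(greeter.greet("Nobody") == null)

    // Verification is independent of stubbing and runs after the fact
    verify(greeter).greet("World")
  }
}
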
Example 1
Source File: TestOracleDataTypeConverter.scala    From ohara   with Apache License 2.0
package oharastream.ohara.connector.jdbc.datatype

import java.sql.ResultSet

import oharastream.ohara.client.configurator.InspectApi.RdbColumn
import oharastream.ohara.common.rule.OharaTest
import org.junit.Test
import org.mockito.Mockito
import org.mockito.Mockito.when
import org.scalatest.matchers.should.Matchers._

class TestOracleDataTypeConverter extends OharaTest {
  @Test
  def testConverterCharValue(): Unit = {
    val resultSet: ResultSet = Mockito.mock(classOf[ResultSet])
    when(resultSet.getString("column1")).thenReturn("value1")
    val column                  = RdbColumn("column1", "CHAR", false)
    val oracleDataTypeConverter = new OracleDataTypeConverter()
    val result                  = oracleDataTypeConverter.converterValue(resultSet, column)
    result shouldBe "value1"
    result.isInstanceOf[String] shouldBe true
  }

  @Test
  def testConverterRawValue(): Unit = {
    val resultSet: ResultSet = Mockito.mock(classOf[ResultSet])
    when(resultSet.getBytes("column1")).thenReturn("aaaa".getBytes)
    val column                  = RdbColumn("column1", "RAW", false)
    val oracleDataTypeConverter = new OracleDataTypeConverter()
    val result                  = oracleDataTypeConverter.converterValue(resultSet, column)
    result.isInstanceOf[Array[Byte]] shouldBe true
    new String(result.asInstanceOf[Array[Byte]]) shouldBe "aaaa"
  }

  @Test
  def testConverterRawNullValue(): Unit = {
    val resultSet: ResultSet = Mockito.mock(classOf[ResultSet])
    when(resultSet.getBytes("column1")).thenReturn(null)
    val column                  = RdbColumn("column1", "RAW", false)
    val oracleDataTypeConverter = new OracleDataTypeConverter()
    val result                  = oracleDataTypeConverter.converterValue(resultSet, column)
    result.isInstanceOf[Array[Byte]] shouldBe true
    result.asInstanceOf[Array[Byte]].length shouldBe 0
  }

  @Test
  def testConverterSmallIntValue(): Unit = {
    val resultSet: ResultSet = Mockito.mock(classOf[ResultSet])
    when(resultSet.getInt("column1")).thenReturn(111)
    val column                  = RdbColumn("column1", "INT", false)
    val oracleDataTypeConverter = new OracleDataTypeConverter()
    val result                  = oracleDataTypeConverter.converterValue(resultSet, column)
    result.isInstanceOf[Integer] shouldBe true
    result.asInstanceOf[Integer] shouldBe 111
  }
} 
Example 2
Source File: OffsetLoaderTest.scala    From toketi-iothubreact   with MIT License
package com.microsoft.azure.iot.iothubreact.checkpointing

import com.microsoft.azure.iot.iothubreact.config.{IConfiguration, IConnectConfiguration}
import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar
import org.mockito.Mockito.when
import org.scalatest.Matchers._

class OffsetLoaderTest extends FunSuite with MockitoSugar {

  test("test GetSavedOffsets handles None appropriately") {

    val config = mock[IConfiguration]
    val cnConfig = mock[IConnectConfiguration]
    when(config.connect) thenReturn(cnConfig)
    when(cnConfig.iotHubPartitions) thenReturn(10)
    val loader = StubbedLoader(config)
    loader.GetSavedOffsets should be(Map(0 → "Offset 0", 1 → "Offset 1", 3 → "Offset 3"))
  }

  case class StubbedLoader(config: IConfiguration) extends OffsetLoader(config) {

    override private[iothubreact] def GetSavedOffset(partition: Int) = {
      partition match {
        case 0 ⇒ Some("Offset 0")
        case 1 ⇒ Some("Offset 1")
        case 3 ⇒ Some("Offset 3")
        case _ ⇒ None
      }
    }
  }

} 
Example 3
Source File: PersonalDetailsControllerSpec.scala    From pertax-frontend   with Apache License 2.0
package controllers.address

import config.ConfigDecorator
import controllers.auth.requests.UserRequest
import controllers.auth.{AuthJourney, WithActiveTabAction}
import controllers.controllershelpers.{AddressJourneyCachingHelper, PersonalDetailsCardGenerator}
import models.AddressJourneyTTLModel
import models.dto.AddressPageVisitedDto
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.{times, verify, when}
import org.mockito.Matchers.{eq => meq, _}
import org.scalatestplus.mockito.MockitoSugar
import play.api.http.Status.OK
import play.api.libs.json.Json
import play.api.mvc.{MessagesControllerComponents, Request, Result}
import play.api.test.FakeRequest
import repositories.EditAddressLockRepository
import services.{LocalSessionCache, NinoDisplayService}
import uk.gov.hmrc.http.cache.client.CacheMap
import uk.gov.hmrc.play.audit.http.connector.{AuditConnector, AuditResult}
import uk.gov.hmrc.play.audit.model.DataEvent
import uk.gov.hmrc.renderer.TemplateRenderer
import util.UserRequestFixture.buildUserRequest
import util.{ActionBuilderFixture, BaseSpec, Fixtures, LocalPartialRetriever}
import views.html.interstitial.DisplayAddressInterstitialView
import views.html.personaldetails.{AddressAlreadyUpdatedView, CannotUseServiceView, PersonalDetailsView}

import scala.concurrent.{ExecutionContext, Future}

class PersonalDetailsControllerSpec extends AddressBaseSpec {

  val ninoDisplayService = mock[NinoDisplayService]

  trait LocalSetup extends AddressControllerSetup {

    when(ninoDisplayService.getNino(any(), any())).thenReturn {
      Future.successful(Some(Fixtures.fakeNino))
    }

    def currentRequest[A]: Request[A] = FakeRequest().asInstanceOf[Request[A]]

    def controller =
      new PersonalDetailsController(
        injected[PersonalDetailsCardGenerator],
        mockEditAddressLockRepository,
        ninoDisplayService,
        mockAuthJourney,
        addressJourneyCachingHelper,
        withActiveTabAction,
        mockAuditConnector,
        cc,
        displayAddressInterstitialView,
        injected[PersonalDetailsView]
      ) {}
  }

  "Calling AddressController.onPageLoad" should {

      "call citizenDetailsService.fakePersonDetails and return 200" in new LocalSetup {
        override def sessionCacheResponse: Option[CacheMap] =
          Some(CacheMap("id", Map("addressPageVisitedDto" -> Json.toJson(AddressPageVisitedDto(true)))))

        val result = controller.onPageLoad()(FakeRequest())

        status(result) shouldBe OK
        verify(mockLocalSessionCache, times(1))
          .cache(meq("addressPageVisitedDto"), meq(AddressPageVisitedDto(true)))(any(), any(), any())
        verify(mockEditAddressLockRepository, times(1)).get(any())
      }

      "send an audit event when user arrives on personal details page" in new LocalSetup {
        override def sessionCacheResponse: Option[CacheMap] =
          Some(CacheMap("id", Map("addressPageVisitedDto" -> Json.toJson(AddressPageVisitedDto(true)))))

        val result = controller.onPageLoad()(FakeRequest())
        val eventCaptor = ArgumentCaptor.forClass(classOf[DataEvent])

        status(result) shouldBe OK
        verify(mockAuditConnector, times(1)).sendEvent(eventCaptor.capture())(any(), any())
      }
  }
}
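Example 3 pairs when with verify(..., times(1)) and an ArgumentCaptor to inspect what a mock was called with. A pared-down sketch of the captor pattern, using a hypothetical Mailer trait:

import org.mockito.ArgumentCaptor
import org.mockito.Mockito.{mock, times, verify}

trait Mailer {
  def send(recipient: String): Unit
}

object CaptorSketch {
  def main(args: Array[String]): Unit = {
    val mailer = mock(classOf[Mailer])
    mailer.send("a@example.com")

    // Capture the argument passed to the mock, then assert on it
    val captor = ArgumentCaptor.forClass(classOf[String])
    verify(mailer, times(1)).send(captor.capture())
    assert(captor.getValue == "a@example.com")
  }
}
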
Example 4
Source File: DiskBlockManagerSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
Example 5
Source File: ExecuteStatementInClientModeWithHDFSSuite.scala    From kyuubi   with Apache License 2.0
package yaooqinn.kyuubi.operation

import java.io.{File, IOException}

import scala.util.Try

import org.apache.hadoop.fs.Path
import org.apache.hadoop.hdfs.{HdfsConfiguration, MiniDFSCluster}
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.sql.catalyst.catalog.FunctionResource
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.mockito.Mockito.when

import yaooqinn.kyuubi.operation.statement.ExecuteStatementInClientMode
import yaooqinn.kyuubi.utils.{KyuubiHiveUtil, ReflectUtils}

class ExecuteStatementInClientModeWithHDFSSuite extends ExecuteStatementInClientModeSuite {
  val hdfsConf = new HdfsConfiguration
  hdfsConf.set("fs.hdfs.impl.disable.cache", "true")
  var cluster: MiniDFSCluster = new MiniDFSCluster.Builder(hdfsConf).build()
  cluster.waitClusterUp()
  val fs = cluster.getFileSystem
  val homeDirectory: Path = fs.getHomeDirectory
  private val fileName = "example-1.0.0-SNAPSHOT.jar"
  private val remoteUDFFile = new Path(homeDirectory, fileName)

  override def beforeAll(): Unit = {
    val file = new File(this.getClass.getProtectionDomain.getCodeSource.getLocation + fileName)
    val localUDFFile = new Path(file.getPath)
    fs.copyFromLocalFile(localUDFFile, remoteUDFFile)
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    fs.delete(remoteUDFFile, true)
    fs.close()
    cluster.shutdown()
    super.afterAll()
  }

  test("transform logical plan") {
    val op = sessionMgr.getOperationMgr.newExecuteStatementOperation(session, statement)
      .asInstanceOf[ExecuteStatementInClientMode]
    val parser = new SparkSqlParser(new SQLConf)
    val plan0 = parser.parsePlan(
      s"create temporary function a as 'a.b.c' using file '$remoteUDFFile'")
    val plan1 = op.transform(plan0)
    assert(plan0 === plan1)
    assert(
      ReflectUtils.getFieldValue(plan1, "resources").asInstanceOf[Seq[FunctionResource]].isEmpty)

    val plan2 = parser.parsePlan(
      s"create temporary function a as 'a.b.c' using jar '$remoteUDFFile'")
    val plan3 = op.transform(plan2)
    assert(plan3 === plan2)
    assert(
      ReflectUtils.getFieldValue(plan3, "resources").asInstanceOf[Seq[FunctionResource]].isEmpty)
  }

  test("add delegation token with hive session state, hdfs") {
    val hiveConf = new HiveConf(hdfsConf, classOf[HiveConf])
    val state = new SessionState(hiveConf)
    assert(Try {
      KyuubiHiveUtil.addDelegationTokensToHiveState(state, UserGroupInformation.getCurrentUser)
    }.isSuccess)

    val mockuser = mock[UserGroupInformation]
    when(mockuser.getUserName).thenThrow(classOf[IOException])
    KyuubiHiveUtil.addDelegationTokensToHiveState(state, mockuser)
  }
} 
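Example 5 stubs a call to fail with thenThrow rather than return a value. A minimal sketch of the same idea, using a hypothetical Reader trait and an unchecked exception (which thenThrow always accepts, regardless of the method's signature):

import org.mockito.Mockito.{mock, when}

trait Reader {
  def read(path: String): String
}

object ThenThrowSketch {
  def main(args: Array[String]): Unit = {
    val reader = mock(classOf[Reader])
    when(reader.read("missing")).thenThrow(new IllegalStateException("boom"))

    // The stubbed exception surfaces when the call is made, not when it is stubbed
    try {
      reader.read("missing")
      assert(false, "expected an IllegalStateException")
    } catch {
      case e: IllegalStateException => assert(e.getMessage == "boom")
    }
  }
}
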
Example 6
Source File: HeartbeatReceiverSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark

import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
import org.mockito.Mockito.{mock, spy, verify, when}
import org.mockito.Matchers
import org.mockito.Matchers._

import org.apache.spark.scheduler.TaskScheduler
import org.apache.spark.util.RpcUtils
import org.scalatest.concurrent.Eventually._

class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext {

  test("HeartbeatReceiver") {
    sc = spy(new SparkContext("local[2]", "test"))
    val scheduler = mock(classOf[TaskScheduler])
    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
    when(sc.taskScheduler).thenReturn(scheduler)

    val heartbeatReceiver = new HeartbeatReceiver(sc)
    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
    eventually(timeout(5 seconds), interval(5 millis)) {
      assert(heartbeatReceiver.scheduler != null)
    }
    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)

    val metrics = new TaskMetrics
    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
    val response = receiverRef.askWithRetry[HeartbeatResponse](
      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))

    verify(scheduler).executorHeartbeatReceived(
      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
    assert(false === response.reregisterBlockManager)
  }

  test("HeartbeatReceiver re-register") {
    sc = spy(new SparkContext("local[2]", "test"))
    val scheduler = mock(classOf[TaskScheduler])
    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
    when(sc.taskScheduler).thenReturn(scheduler)

    val heartbeatReceiver = new HeartbeatReceiver(sc)
    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
    eventually(timeout(5 seconds), interval(5 millis)) {
      assert(heartbeatReceiver.scheduler != null)
    }
    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)

    val metrics = new TaskMetrics
    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
    val response = receiverRef.askWithRetry[HeartbeatResponse](
      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))

    verify(scheduler).executorHeartbeatReceived(
      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
    assert(true === response.reregisterBlockManager)
  }
} 
Example 7
Source File: LogPageSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  // Write the given string to the file, closing the writer even on failure
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 8
Source File: PipelineSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.ml

import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame

class PipelineSuite extends SparkFunSuite {

  abstract class MyModel extends Model[MyModel]

  test("pipeline") {
    val estimator0 = mock[Estimator[MyModel]]
    val model0 = mock[MyModel]
    val transformer1 = mock[Transformer]
    val estimator2 = mock[Estimator[MyModel]]
    val model2 = mock[MyModel]
    val transformer3 = mock[Transformer]
    val dataset0 = mock[DataFrame]
    val dataset1 = mock[DataFrame]
    val dataset2 = mock[DataFrame]
    val dataset3 = mock[DataFrame]
    val dataset4 = mock[DataFrame]

    when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
    when(model0.copy(any[ParamMap])).thenReturn(model0)
    when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
    when(estimator2.copy(any[ParamMap])).thenReturn(estimator2)
    when(model2.copy(any[ParamMap])).thenReturn(model2)
    when(transformer3.copy(any[ParamMap])).thenReturn(transformer3)

    when(estimator0.fit(meq(dataset0))).thenReturn(model0)
    when(model0.transform(meq(dataset0))).thenReturn(dataset1)
    when(model0.parent).thenReturn(estimator0)
    when(transformer1.transform(meq(dataset1))).thenReturn(dataset2)
    when(estimator2.fit(meq(dataset2))).thenReturn(model2)
    when(model2.transform(meq(dataset2))).thenReturn(dataset3)
    when(model2.parent).thenReturn(estimator2)
    when(transformer3.transform(meq(dataset3))).thenReturn(dataset4)

    val pipeline = new Pipeline()
      .setStages(Array(estimator0, transformer1, estimator2, transformer3))
    val pipelineModel = pipeline.fit(dataset0)

    assert(pipelineModel.stages.length === 4)
    assert(pipelineModel.stages(0).eq(model0))
    assert(pipelineModel.stages(1).eq(transformer1))
    assert(pipelineModel.stages(2).eq(model2))
    assert(pipelineModel.stages(3).eq(transformer3))

    val output = pipelineModel.transform(dataset0)
    assert(output.eq(dataset4))
  }

  test("pipeline with duplicate stages") {
    val estimator = mock[Estimator[MyModel]]
    val pipeline = new Pipeline()
      .setStages(Array(estimator, estimator))
    val dataset = mock[DataFrame]
    intercept[IllegalArgumentException] {
      pipeline.fit(dataset)
    }
  }

  test("PipelineModel.copy") {
    val hashingTF = new HashingTF()
      .setNumFeatures(100)
    val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
    val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
    require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
      "copy should handle extra stage params")
  }
} 
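Examples 3 and 8 mix argument matchers (any(), and eq aliased to meq) into their stubs. One Mockito rule worth calling out: once any argument of a call uses a matcher, every argument must, which is why literal values get wrapped in meq(...). A minimal sketch with a hypothetical Repo trait, using the same org.mockito.Matchers import style as the examples above:

import org.mockito.Matchers.{anyInt, eq => meq}
import org.mockito.Mockito.{mock, when}

trait Repo {
  def find(table: String, id: Int): String
}

object MatcherSketch {
  def main(args: Array[String]): Unit = {
    val repo = mock(classOf[Repo])
    // meq("users") matches that exact value; anyInt matches any id
    when(repo.find(meq("users"), anyInt)).thenReturn("row")

    assert(repo.find("users", 42) == "row")
    assert(repo.find("orders", 42) == null) // unstubbed calls get the default value
  }
}
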
Example 9
Source File: StagePageSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{RETURNS_SMART_NULLS, mock, when}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener
import org.apache.spark.util.Utils

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory should be displayed") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
  }

  
  // Render a stage page with the given conf; runs a dummy stage to populate the page
  private def renderStagePage(conf: SparkConf): Seq[Node] = {

    val jobListener = new JobProgressListener(conf, Utils.getCurrentUserName())
    val graphListener = new RDDOperationGraphListener(conf)
    val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.executorsListener).thenReturn(executorsListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markFinished(TaskState.FINISHED)
        val taskMetrics = TaskMetrics.empty
        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
Example 10
Source File: LogPageSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  // Write the given string to the file, closing the writer even on failure
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 11
Source File: MasterWebUISuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  // Send an HTTP request of the given method with the given body and return the connection
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 12
Source File: JdbcConnectorTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.connectors

import java.sql.{Connection, DatabaseMetaData, ResultSet, Statement}

import com.aol.one.dwh.infra.config._
import com.aol.one.dwh.infra.sql.pool.HikariConnectionPool
import com.aol.one.dwh.infra.sql.{ListStringResultHandler, Setting, VerticaMaxValuesQuery}
import org.apache.commons.dbutils.ResultSetHandler
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class JdbcConnectorTest extends FunSuite with MockitoSugar {

  private val statement = mock[Statement]
  private val resultSet = mock[ResultSet]
  private val connectionPool = mock[HikariConnectionPool]
  private val connection = mock[Connection]
  private val databaseMetaData = mock[DatabaseMetaData]
  private val resultSetHandler = mock[ResultSetHandler[Long]]
  private val listStringResultHandler = mock[ListStringResultHandler]

  test("check run query result for numeric batch_id column") {
    val resultValue = 100L
    val table = Table("table", List("column"), None)
    val query = VerticaMaxValuesQuery(table)
    when(connectionPool.getConnection).thenReturn(connection)
    when(connectionPool.getName).thenReturn("connection_pool_name")
    when(connection.createStatement()).thenReturn(statement)
    when(statement.executeQuery("SELECT MAX(column) AS column FROM table")).thenReturn(resultSet)
    when(connection.getMetaData).thenReturn(databaseMetaData)
    when(databaseMetaData.getURL).thenReturn("connection_url")
    when(resultSetHandler.handle(resultSet)).thenReturn(resultValue)

    val result = new DefaultJdbcConnector(connectionPool).runQuery(query, resultSetHandler)

    assert(result == resultValue)
  }

  test("check run query result for date/time partitions") {
    val resultValue = Some(20190924L)
    val table = Table("table", List("year", "month", "day"), Some(List("yyyy", "MM", "dd")))
    val query = VerticaMaxValuesQuery(table)
    when(connectionPool.getConnection).thenReturn(connection)
    when(connectionPool.getName).thenReturn("connection_pool_name")
    when(connection.createStatement()).thenReturn(statement)
    when(statement.executeQuery("SELECT DISTINCT year, month, day FROM table")).thenReturn(resultSet)
    when(connection.getMetaData).thenReturn(databaseMetaData)
    when(databaseMetaData.getURL).thenReturn("connection_url")
    when(listStringResultHandler.handle(resultSet)).thenReturn(resultValue)

    val result = new DefaultJdbcConnector(connectionPool).runQuery(query, listStringResultHandler)

    assert(result == resultValue)
  }
}

class DefaultJdbcConnector(connectionPool: HikariConnectionPool) extends JdbcConnector(connectionPool) {
  override def applySetting(connection: Connection, statement: Statement, setting: Setting): Unit = {}
} 
Example 13
Source File: GlueConnectorTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.connectors
import com.aol.one.dwh.infra.config._
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

import scala.concurrent.duration._

class GlueConnectorTest extends FunSuite with MockitoSugar {

  private val config = GlueConfig("eu-central-1", "default", "accessKey", "secretKey", 5, 2, 10.seconds)
  private val glueConnector = mock[GlueConnector]

  test("Check max batchId from glue metadata tables") {
    val resultValue = 100L
    val numericTable = Table("table", List("column"), None)
    when(glueConnector.getMaxPartitionValue(numericTable)).thenReturn(resultValue)

    val result = glueConnector.getMaxPartitionValue(numericTable)

    assert(result == resultValue)
  }

  test("Check max date partitions' value from glue metadata table") {
    val resultValue = 15681377656L
    val datetimeTable = Table("table", List("year", "month", "day"), Some(List("yyyy", "MM", "dd")))
    when(glueConnector.getMaxPartitionValue(datetimeTable)).thenReturn(resultValue)

    val result = glueConnector.getMaxPartitionValue(datetimeTable)

    assert(result == resultValue)
  }
} 
Example 14
Source File: KafkaInMessagesProviderTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.providers

import com.aol.one.dwh.bandarlog.connectors.KafkaConnector
import com.aol.one.dwh.infra.config.Topic
import kafka.common.TopicAndPartition
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar


class KafkaInMessagesProviderTest extends FunSuite with MockitoSugar {

  private val kafkaConnector = mock[KafkaConnector]
  private val topic = Topic("topic_id", Set("topic_1", "topic_2"), "group_id")

  test("check count of in messages/heads over all topic partitions") {
    val heads = Some(Map(
      TopicAndPartition("topic_1", 1) -> 1L,
      TopicAndPartition("topic_2", 2) -> 2L,
      TopicAndPartition("topic_3", 3) -> 3L
    ))
    when(kafkaConnector.getHeads(topic)).thenReturn(heads)

    val result = new KafkaInMessagesProvider(kafkaConnector, topic).provide()

    assert(result.getValue.nonEmpty)
    assert(result.getValue.get == 6) // 1 + 2 + 3
  }

  test("check count of in messages/heads for empty heads result") {
    when(kafkaConnector.getHeads(topic)).thenReturn(Some(Map[TopicAndPartition, Long]()))

    val result = new KafkaInMessagesProvider(kafkaConnector, topic).provide()

    assert(result.getValue.nonEmpty)
    assert(result.getValue.get == 0)
  }

  test("return none if can't retrieve heads") {
    when(kafkaConnector.getHeads(topic)).thenReturn(None)

    val result = new KafkaInMessagesProvider(kafkaConnector, topic).provide()

    assert(result.getValue.isEmpty)
  }
} 
Example 15
Source File: KafkaLagProviderTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.providers

import com.aol.one.dwh.infra.config.Topic
import com.aol.one.dwh.bandarlog.connectors.KafkaConnector
import kafka.common.TopicAndPartition
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar


class KafkaLagProviderTest extends FunSuite with MockitoSugar {

  private val kafkaConnector = mock[KafkaConnector]
  private val topic = Topic("topic_id", Set("topic_1", "topic_2", "topic_3"), "group_id")

  test("check lag per topic") {
    val heads = Map(
      TopicAndPartition("topic_1", 1) -> 4L,
      TopicAndPartition("topic_2", 2) -> 5L,
      TopicAndPartition("topic_3", 3) -> 6L
    )

    val offsets = Map(
      TopicAndPartition("topic_1", 1) -> 1L,
      TopicAndPartition("topic_2", 2) -> 2L,
      TopicAndPartition("topic_3", 3) -> 3L
    )
    val kafkaState = Option((heads, offsets))
    when(kafkaConnector.getKafkaState(topic)).thenReturn(kafkaState)

    val result = new KafkaLagProvider(kafkaConnector, topic).provide()

    // topic       partition  heads  offsets  lag
    // topic_1     1          4      1        4-1=3
    // topic_2     2          5      2        5-2=3
    // topic_3     3          6      3        6-3=3
    assert(result.getValue.nonEmpty)
    assert(result.getValue.get == 9) // lag sum 3 + 3 + 3
  }

  test("check 0 lag case per topic") {
    val heads = Map(
      TopicAndPartition("topic_1", 1) -> 1L,
      TopicAndPartition("topic_2", 2) -> 2L,
      TopicAndPartition("topic_3", 3) -> 3L
    )

    val offsets = Map(
      TopicAndPartition("topic_1", 1) -> 4L,
      TopicAndPartition("topic_2", 2) -> 5L,
      TopicAndPartition("topic_3", 3) -> 6L
    )
    val kafkaState = Option((heads, offsets))
    when(kafkaConnector.getKafkaState(topic)).thenReturn(kafkaState)

    val result = new KafkaLagProvider(kafkaConnector, topic).provide()

    // topic       partition  heads  offsets  lag
    // topic_1     1          1      4        1-4= -3
    // topic_2     2          2      5        2-5= -3
    // topic_3     3          3      6        3-6= -3
    assert(result.getValue.nonEmpty)
    assert(result.getValue.get == 0) // lag.max(0) = 0
  }

  test("check lag for empty heads and offsets") {
    val kafkaState = Option((Map[TopicAndPartition, Long](), Map[TopicAndPartition, Long]()))
    when(kafkaConnector.getKafkaState(topic)).thenReturn(kafkaState)

    val result = new KafkaLagProvider(kafkaConnector, topic).provide()

    assert(result.getValue.nonEmpty)
    assert(result.getValue.get == 0)
  }

  test("return none if can't retrieve kafka state") {
    when(kafkaConnector.getKafkaState(topic)).thenReturn(None)

    val result = new KafkaLagProvider(kafkaConnector, topic).provide()

    assert(result.getValue.isEmpty)
  }
} 
Example 16
Source File: KafkaOutMessagesProviderTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.providers

import com.aol.one.dwh.bandarlog.connectors.KafkaConnector
import com.aol.one.dwh.infra.config.Topic
import kafka.common.TopicAndPartition
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar


class KafkaOutMessagesProviderTest extends FunSuite with MockitoSugar {

  private val kafkaConnector = mock[KafkaConnector]
  private val topic = Topic("topic_id", Set("topic_1", "topic_2"), "group_id")

  test("check count of out messages/offsets over all topic partitions") {
    val offsets = Option(Map(
      TopicAndPartition("topic_1", 1) -> 1L,
      TopicAndPartition("topic_2", 2) -> 2L,
      TopicAndPartition("topic_3", 3) -> 3L
    ))
    when(kafkaConnector.getOffsets(topic)).thenReturn(offsets)

    val result = new KafkaOutMessagesProvider(kafkaConnector, topic).provide()

    assert(result.getValue.nonEmpty)
    assert(result.getValue.get == 6) // 1 + 2 + 3
  }

  test("check count of out messages/offsets for empty offsets result") {
    when(kafkaConnector.getOffsets(topic)).thenReturn(Some(Map[TopicAndPartition, Long]()))

    val result = new KafkaOutMessagesProvider(kafkaConnector, topic).provide()

    assert(result.getValue.nonEmpty)
    assert(result.getValue.get == 0)
  }

  test("return none if can't retrieve offsets") {
    when(kafkaConnector.getOffsets(topic)).thenReturn(None)

    val result = new KafkaOutMessagesProvider(kafkaConnector, topic).provide()

    assert(result.getValue.isEmpty)
  }
} 
Example 17
Source File: SqlLagProviderTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.providers

import com.aol.one.dwh.bandarlog.metrics.AtomicValue
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar


class SqlLagProviderTest extends FunSuite with MockitoSugar {

  private val fromProvider = mock[SqlTimestampProvider]
  private val toProvider = mock[SqlTimestampProvider]
  private val toGlueProvider = mock[GlueTimestampProvider]
  private val lagProvider1 = new SqlLagProvider(fromProvider, toProvider)
  private val lagProvider2 = new SqlLagProvider(fromProvider, toGlueProvider)

  test("check lag between from and to providers") {
    val fromValue = AtomicValue(Some(7L))
    val toValue = AtomicValue(Some(4L))
    val toGlueValue = AtomicValue(Some(6L))

    when(fromProvider.provide()).thenReturn(fromValue)
    when(toProvider.provide()).thenReturn(toValue)
    when(toGlueProvider.provide()).thenReturn(toGlueValue)

    val lag1 = lagProvider1.provide()
    val lag2 = lagProvider2.provide()

    assert(lag1.getValue.nonEmpty)
    assert(lag1.getValue.get == 3)
    assert(lag2.getValue.nonEmpty)
    assert(lag2.getValue.get == 1)
  }

  test("return none if 'from provider' value is none") {
    val toValue = AtomicValue(Some(4L))

    when(fromProvider.provide()).thenReturn(AtomicValue[Long](None))
    when(toProvider.provide()).thenReturn(toValue)

    val lag = lagProvider1.provide()

    assert(lag.getValue.isEmpty)
  }

  test("return none if 'to provider' value is none") {
    val fromValue = AtomicValue(Some(7L))

    when(fromProvider.provide()).thenReturn(fromValue)
    when(toProvider.provide()).thenReturn(AtomicValue[Long](None))
    when(toGlueProvider.provide()).thenReturn(AtomicValue[Long](None))

    val lag1 = lagProvider1.provide()
    val lag2 = lagProvider2.provide()

    assert(lag1.getValue.isEmpty)
    assert(lag2.getValue.isEmpty)
  }

  test("return none if both providers' values are none") {
    when(fromProvider.provide()).thenReturn(AtomicValue[Long](None))
    when(toProvider.provide()).thenReturn(AtomicValue[Long](None))
    when(toGlueProvider.provide()).thenReturn(AtomicValue[Long](None))

    val lag1 = lagProvider1.provide()
    val lag2 = lagProvider2.provide()

    assert(lag1.getValue.isEmpty)
    assert(lag2.getValue.isEmpty)
  }
} 
Example 18
Source File: GlueTimestampProviderTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.providers

import com.aol.one.dwh.bandarlog.connectors.GlueConnector
import com.aol.one.dwh.infra.config.Table
import org.mockito.Matchers.any
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class GlueTimestampProviderTest extends FunSuite with MockitoSugar {

  private val table = mock[Table]
  private val glueConnector = mock[GlueConnector]
  private val glueTimestampProvider = new GlueTimestampProvider(glueConnector, table)

  test("check timestamp value by glue connector and table") {
    val glueTimestamp = 1533709910004L
    when(glueConnector.getMaxPartitionValue(any())).thenReturn(glueTimestamp)

    val result = glueTimestampProvider.provide()

    assert(result.getValue == Some(glueTimestamp))
  }

  test("return zero if partition column does not have values") {
    when(glueConnector.getMaxPartitionValue(any())).thenReturn(0)

    val result = glueTimestampProvider.provide()

    assert(result.getValue == Some(0))
  }
} 
Example 19
Source File: HttpVerbSpec.scala    From http-verbs   with Apache License 2.0
package uk.gov.hmrc.http

import java.util

import com.typesafe.config.Config
import org.mockito.Matchers.any
import org.mockito.Mockito.when
import org.scalatest.LoneElement
import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.wordspec.AnyWordSpecLike
import org.scalatest.matchers.should.Matchers
import uk.gov.hmrc.http.logging.{Authorization, ForwardedFor, RequestId, SessionId}

class HttpVerbSpec extends AnyWordSpecLike with Matchers with MockitoSugar with LoneElement {

  "applicableHeaders" should {

    "should contain the values passed in by header-carrier" in {
      val url = "http://test.me"

      implicit val hc = HeaderCarrier(
        authorization = Some(Authorization("auth")),
        sessionId     = Some(SessionId("session")),
        requestId     = Some(RequestId("request")),
        token         = Some(Token("token")),
        forwarded     = Some(ForwardedFor("forwarded"))
      )

      val httpRequest = new HttpVerb {
        override def configuration: Option[Config] = None
      }
      val result = httpRequest.applicableHeaders(url)

      result shouldBe hc.headers
    }

    "should include the User-Agent header when the 'appName' config value is present" in {

      val mockedConfig = mock[Config]
      when(mockedConfig.getStringList(any())).thenReturn(new util.ArrayList[String]())
      when(mockedConfig.getString("appName")).thenReturn("myApp")
      when(mockedConfig.hasPathOrNull("appName")).thenReturn(true)

      val httpRequest = new HttpVerb {
        override def configuration: Option[Config] = Some(mockedConfig)
      }
      val result = httpRequest.applicableHeaders("http://test.me")(HeaderCarrier())

      result.contains("User-Agent" -> "myApp") shouldBe true
    }

    "filter 'remaining headers' from request for external service calls" in {

      implicit val hc = HeaderCarrier(
        otherHeaders = Seq("foo" -> "secret!")
      )

      val httpRequest = new HttpVerb {
        override def configuration: Option[Config] = None
      }
      val result = httpRequest.applicableHeaders("http://test.me")
      result.map(_._1) should not contain "foo"
    }

    "include 'remaining headers' in request for internal service call to .service URL" in {
      implicit val hc = HeaderCarrier(
        otherHeaders = Seq("foo" -> "secret!")
      )
      val httpRequest = new HttpVerb {
        override def configuration: Option[Config] = None
      }

      for { url <- List("http://test.public.service/bar", "http://test.public.mdtp/bar") } {

        val result = httpRequest.applicableHeaders(url)
        assert(result.contains("foo" -> "secret!"), s"'other/remaining headers' for $url were not present")

      }
    }

    "include 'remaining headers' in request for internal service call to other configured internal URL pattern" in {
      val url = "http://localhost/foo" // an internal service call, according to config
      implicit val hc = HeaderCarrier(
        otherHeaders = Seq("foo" -> "secret!")
      )

      import scala.collection.JavaConversions._
      val mockedConfig = mock[Config]
      when(mockedConfig.getStringList("internalServiceHostPatterns")).thenReturn(List("localhost"))
      when(mockedConfig.hasPathOrNull("internalServiceHostPatterns")).thenReturn(true)

      val httpRequest = new HttpVerb {
        override def configuration: Option[Config] = Some(mockedConfig)
      }
      val result = httpRequest.applicableHeaders(url)
      result.contains("foo" -> "secret!") shouldBe true
    }

  }

} 
Example 20
Source File: KyuubiDistributedCacheManagerSuite.scala    From kyuubi   with Apache License 2.0
package org.apache.spark.deploy.yarn

import java.net.URI

import scala.collection.mutable.{HashMap, Map}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.yarn.api.records.{LocalResource, LocalResourceType, LocalResourceVisibility}
import org.apache.hadoop.yarn.util.ConverterUtils
import org.apache.spark.{KyuubiSparkUtil, SparkFunSuite}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar

import yaooqinn.kyuubi.utils.ReflectUtils

class KyuubiDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar {

  class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
    override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
    LocalResourceVisibility = {
      LocalResourceVisibility.PRIVATE
    }
  }

  test("add resource") {
    val fs = mock[FileSystem]
    val conf = new Configuration()
    val destPath = new Path("file:///foo.bar.com:8080/tmp/testing")
    val localResources = HashMap[String, LocalResource]()
    val statCache = HashMap[URI, FileStatus]()
    val status = new FileStatus()
    when(fs.getFileStatus(destPath)).thenReturn(status)
    val fileLink = "link"
    ReflectUtils.setFieldValue(
      KyuubiDistributedCacheManager, "cacheManager", new MockClientDistributedCacheManager)
    KyuubiDistributedCacheManager.addResource(
      fs, conf, destPath, localResources, LocalResourceType.FILE, fileLink, statCache)
    val res = localResources(fileLink)
    assert(res.getVisibility === LocalResourceVisibility.PRIVATE)
    assert(ConverterUtils.getPathFromYarnURL(res.getResource) === destPath)
    assert(res.getSize === 0)
    assert(res.getTimestamp === 0)
    assert(res.getType === LocalResourceType.FILE)
    val status2 = new FileStatus(
      10, false, 1, 1024, 10,
      10, null, KyuubiSparkUtil.getCurrentUserName, null, new Path("/tmp/testing2"))
    val destPath2 = new Path("file:///foo.bar.com:8080/tmp/testing2")
    when(fs.getFileStatus(destPath2)).thenReturn(status2)
    val fileLink2 = "link2"
    KyuubiDistributedCacheManager.addResource(
      fs, conf, destPath2, localResources, LocalResourceType.FILE, fileLink2, statCache)
    val res2 = localResources(fileLink2)
    assert(res2.getVisibility === LocalResourceVisibility.PRIVATE)
    assert(ConverterUtils.getPathFromYarnURL(res2.getResource) === destPath2)
    assert(res2.getSize === 10)
    assert(res2.getTimestamp === 10)
    assert(res2.getType === LocalResourceType.FILE)
  }

  test("add resource when link null") {
    val distMgr = new MockClientDistributedCacheManager()
    val fs = mock[FileSystem]
    val conf = new Configuration()
    val destPath = new Path("file:///foo.bar.com:8080/tmp/testing")
    ReflectUtils.setFieldValue(KyuubiDistributedCacheManager, "cacheManager", distMgr)
    val localResources = HashMap[String, LocalResource]()
    val statCache = HashMap[URI, FileStatus]()
    when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
    intercept[Exception] {
      KyuubiDistributedCacheManager.addResource(
        fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache)
    }
    assert(localResources.get("link") === None)
    assert(localResources.size === 0)
  }

  test("test addResource archive") {
    val distMgr = new MockClientDistributedCacheManager()
    ReflectUtils.setFieldValue(KyuubiDistributedCacheManager, "cacheManager", distMgr)
    val fs = mock[FileSystem]
    val conf = new Configuration()
    val destPath = new Path("file:///foo.bar.com:8080/tmp/testing")
    val localResources = HashMap[String, LocalResource]()
    val statCache = HashMap[URI, FileStatus]()
    val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
      null, new Path("/tmp/testing"))
    when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)

    KyuubiDistributedCacheManager.addResource(
      fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache)
    val resource = localResources("link")
    assert(resource.getVisibility === LocalResourceVisibility.PRIVATE)
    assert(ConverterUtils.getPathFromYarnURL(resource.getResource) === destPath)
    assert(resource.getTimestamp === 10)
    assert(resource.getSize === 10)
    assert(resource.getType === LocalResourceType.ARCHIVE)

  }

} 
Example 21
Source File: KyuubiSessionSubPageSuite.scala    From kyuubi   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.util.Try

import org.apache.spark.{KyuubiSparkUtil, SparkConf, SparkContext, SparkFunSuite}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar

import yaooqinn.kyuubi.ui.{ExecutionInfo, KyuubiServerListener, SessionInfo}

class KyuubiSessionSubPageSuite extends SparkFunSuite with MockitoSugar {

  var sc: SparkContext = _
  var user: String = _
  var tab: KyuubiSessionTab = _

  override def beforeAll(): Unit = {
    val conf = new SparkConf(loadDefaults = true).setMaster("local").setAppName("test")
    sc = new SparkContext(conf)
    user = KyuubiSparkUtil.getCurrentUserName
    tab = new KyuubiSessionTab(user, sc)
  }

  override def afterAll(): Unit = {
    sc.stop()
  }

  test("render kyuubi session page") {
    val page = new KyuubiSessionSubPage(tab)

    val request = mock[HttpServletRequest]
    intercept[IllegalArgumentException](page.render(request))

    val id = "id1"
    when(request.getParameter("id")).thenReturn(id)
    intercept[IllegalArgumentException](page.render(request))

    val sessionInfo = mock[SessionInfo]
    val tab1 = mock[KyuubiSessionTab]
    when(request.getParameter("id")).thenReturn(id)
    val listener = mock[KyuubiServerListener]
    when(tab1.listener).thenReturn(listener)
    when(listener.getSession(id)).thenReturn(Some(sessionInfo))
    when(sessionInfo.sessionId).thenReturn("1")
    when(listener.getExecutionList).thenReturn(Seq[ExecutionInfo]())
    when(tab1.appName).thenReturn("name")
    when(tab1.headerTabs).thenReturn(Seq[WebUITab]())
    val page2 = new KyuubiSessionSubPage(tab1)
    assert(Try { page2.render(request) }.isSuccess)
  }
} 
Example 22
Source File: StagePageSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {
  // Peak execution memory is only displayed when unsafe is enabled
  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    println("===="+html)
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
    // Disable unsafe and make sure it's not there
    val conf2 = new SparkConf(false).set(unsafeConf, "false")
    val html2 = renderStagePage(conf2).toString().toLowerCase
    assert(!html2.contains(targetString))
    // Avoid setting anything; it should be displayed by default
    val conf3 = new SparkConf(false)
    val html3 = renderStagePage(conf3).toString().toLowerCase
    assert(html3.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains("<td>10.0 b</td>" * 5))
  }

  
  // Render a stage page with the given conf; runs a dummy stage to populate the page
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        val peakExecutionMemory = 10
        taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY,
          Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markSuccessful()
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
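Note the RETURNS_SMART_NULLS default answer passed to mock(...) above: instead of returning plain null for unstubbed calls, the StagesTab mock returns "smart nulls" that raise a descriptive error when dereferenced. A small sketch of the difference, using hypothetical Tab/Listener traits rather than Spark types:

import org.mockito.Mockito.{mock, RETURNS_SMART_NULLS}

trait Listener { def name: String }
trait Tab { def listener: Listener }

object SmartNullSketch extends App {
  val plain = mock(classOf[Tab])
  val smart = mock(classOf[Tab], RETURNS_SMART_NULLS)
  assert(plain.listener == null) // default answer: plain null
  val proxy = smart.listener     // a non-null "smart null" proxy
  // Calling proxy.name fails with a SmartNullPointerException that names
  // the unstubbed call, which is far easier to debug than a bare NPE.
}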
Example 23
Source File: DiskBlockManagerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils
// DiskBlockManager manages and maintains the mapping between logical blocks and the physical blocks stored on disk.
// In general, a logical block is mapped to a physical file whose name is derived from its BlockId.
class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  // DiskBlockManager creates and maintains the logical mapping between logical blocks and physical disk locations; by default, a block is mapped to a single file whose name is given by its BlockId
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {//基本块的创建
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {//枚举块
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
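The suite above stubs when(blockManager.conf) directly in the class body, before any test runs; this works because the mock field is initialized eagerly and the object under test only reads the stub later, in beforeEach. A compact sketch of that ordering, with illustrative names standing in for BlockManager/DiskBlockManager:

import org.mockito.Mockito.{mock, when}

// Hypothetical stand-ins, for illustration only
trait BlockManagerLike { def conf: String }
class DiskManagerLike(bm: BlockManagerLike) { val mode: String = bm.conf }

object ConstructionStubSketch extends App {
  val bm = mock(classOf[BlockManagerLike])
  when(bm.conf).thenReturn("test-conf") // stub before the subject reads it
  val dm = new DiskManagerLike(bm)
  assert(dm.mode == "test-conf")
}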
Example 24
Source File: JsonEncoderSpec.scala    From logback-json-logger   with Apache License 2.0 5 votes vote down vote up
package uk.gov.hmrc.play.logging
import java.io.{PrintWriter, StringWriter}
import java.net.InetAddress

import ch.qos.logback.classic.Level
import ch.qos.logback.classic.spi.{ILoggingEvent, ThrowableProxy}
import ch.qos.logback.core.ContextBase
import org.apache.commons.lang3.time.FastDateFormat
import org.mockito.Mockito.when
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import org.scalatestplus.mockito.MockitoSugar
import play.api.libs.json.{JsLookupResult, Json}

import scala.collection.JavaConverters._

class JsonEncoderSpec extends AnyWordSpec with Matchers with MockitoSugar {

  "Json-encoded message" should {
    "contain all required fields" in {

      val jsonEncoder = new JsonEncoder()
      val event       = mock[ILoggingEvent]

      when(event.getTimeStamp).thenReturn(1)
      when(event.getLevel).thenReturn(Level.INFO)
      when(event.getThreadName).thenReturn("my-thread")
      when(event.getFormattedMessage).thenReturn("my-message")
      when(event.getLoggerName).thenReturn("logger-name")
      when(event.getMDCPropertyMap).thenReturn(Map("myMdcProperty" -> "myMdcValue").asJava)

      val testException = new Exception("test-exception")
      val stringWriter  = new StringWriter()
      testException.printStackTrace(new PrintWriter(stringWriter))
      when(event.getThrowableProxy).thenReturn(new ThrowableProxy(testException))

      jsonEncoder.setContext {
        val ctx = new ContextBase()
        ctx.putProperty("myKey", "myValue")
        ctx
      }

      val result       = new String(jsonEncoder.encode(event), "UTF-8")
      val resultAsJson = Json.parse(result)

      (resultAsJson \ "app").asString           shouldBe "my-app-name"
      (resultAsJson \ "hostname").asString      shouldBe InetAddress.getLocalHost.getHostName
      (resultAsJson \ "timestamp").asString     shouldBe FastDateFormat.getInstance("yyyy-MM-dd HH:mm:ss.SSSZZ").format(1)
      (resultAsJson \ "message").asString       shouldBe "my-message"
      (resultAsJson \ "exception").asString     should include("test-exception")
      (resultAsJson \ "exception").asString     should include("java.lang.Exception")
      (resultAsJson \ "exception").asString     should include(stringWriter.toString)
      (resultAsJson \ "logger").asString        shouldBe "logger-name"
      (resultAsJson \ "thread").asString        shouldBe "my-thread"
      (resultAsJson \ "level").asString         shouldBe "INFO"
      (resultAsJson \ "mykey").asString         shouldBe "myValue"
      (resultAsJson \ "mymdcproperty").asString shouldBe "myMdcValue"

    }
  }

  implicit class JsLookupResultOps(jsLookupResult: JsLookupResult) {
    def asString: String = jsLookupResult.get.as[String]
  }

} 
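One detail worth copying from the spec above: getMDCPropertyMap returns a java.util.Map, so the Scala literal is converted with asJava before being handed to thenReturn. A minimal sketch of that conversion, with a hypothetical Event trait:

import org.mockito.Mockito.{mock, when}
import scala.collection.JavaConverters._

trait Event { def getMDCPropertyMap: java.util.Map[String, String] }

object JavaMapStubSketch extends App {
  val event = mock(classOf[Event])
  // Convert the Scala map to its Java counterpart to match the return type
  when(event.getMDCPropertyMap).thenReturn(Map("k" -> "v").asJava)
  assert(event.getMDCPropertyMap.get("k") == "v")
}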
Example 25
Source File: BasicInitContainerConfigurationStepSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.k8s.submit.steps.initcontainer

import scala.collection.JavaConverters._

import io.fabric8.kubernetes.api.model._
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Matchers.any
import org.mockito.Mockito.when
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer}
import org.apache.spark.deploy.k8s.Config._

class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter {

  private val SPARK_JARS = Seq(
    "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar")
  private val SPARK_FILES = Seq(
    "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt")
  private val JARS_DOWNLOAD_PATH = "/var/data/jars"
  private val FILES_DOWNLOAD_PATH = "/var/data/files"
  private val POD_LABEL = Map("bootstrap" -> "true")
  private val INIT_CONTAINER_NAME = "init-container"
  private val DRIVER_CONTAINER_NAME = "driver-container"

  @Mock
  private var podAndInitContainerBootstrap: InitContainerBootstrap = _

  before {
    MockitoAnnotations.initMocks(this)
    when(podAndInitContainerBootstrap.bootstrapInitContainer(
      any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] {
      override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = {
        val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer])
        pod.copy(
          pod = new PodBuilder(pod.pod)
            .withNewMetadata()
            .addToLabels("bootstrap", "true")
            .endMetadata()
            .withNewSpec().endSpec()
            .build(),
          initContainer = new ContainerBuilder()
            .withName(INIT_CONTAINER_NAME)
            .build(),
          mainContainer = new ContainerBuilder()
            .withName(DRIVER_CONTAINER_NAME)
            .build()
        )}})
  }

  test("additionalDriverSparkConf with mix of remote files and jars") {
    val baseInitStep = new BasicInitContainerConfigurationStep(
      SPARK_JARS,
      SPARK_FILES,
      JARS_DOWNLOAD_PATH,
      FILES_DOWNLOAD_PATH,
      podAndInitContainerBootstrap)
    val expectedDriverSparkConf = Map(
      JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH,
      FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH,
      INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar",
      INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt")
    val initContainerSpec = InitContainerSpec(
      Map.empty[String, String],
      Map.empty[String, String],
      new Container(),
      new Container(),
      new Pod,
      Seq.empty[HasMetadata])
    val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec)
    assert(expectedDriverSparkConf === returnContainerSpec.properties)
    assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME)
    assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME)
    assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL)
  }
} 
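The before block above uses thenAnswer rather than thenReturn, so the stub can transform whatever pod it receives. A stripped-down sketch of the same Answer pattern (Mockito 1.x style, matching the getArgumentAt call used above; the Decorator trait is hypothetical):

import org.mockito.Matchers.any
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

trait Decorator { def decorate(s: String): String }

object AnswerSketch extends App {
  val deco = mock(classOf[Decorator])
  when(deco.decorate(any[String])).thenAnswer(new Answer[String] {
    override def answer(invocation: InvocationOnMock): String =
      "<<" + invocation.getArgumentAt(0, classOf[String]) + ">>"
  })
  assert(deco.decorate("pod") == "<<pod>>") // the answer sees the real argument
}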
Example 26
Source File: YarnSchedulerBackendSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import scala.language.reflectiveCalls

import org.mockito.Mockito.when
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.serializer.JavaSerializer

class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext {

  test("RequestExecutors reflects node blacklist and is serializable") {
    sc = new SparkContext("local", "YarnSchedulerBackendSuite")
    val sched = mock[TaskSchedulerImpl]
    when(sched.sc).thenReturn(sc)
    val yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) {
      def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = {
        this.hostToLocalTaskCount = hostToLocalTaskCount
      }
    }
    val ser = new JavaSerializer(sc.conf).newInstance()
    for {
      blacklist <- IndexedSeq(Set[String](), Set("a", "b", "c"))
      numRequested <- 0 until 10
      hostToLocalCount <- IndexedSeq(
        Map[String, Int](),
        Map("a" -> 1, "b" -> 2)
      )
    } {
      yarnSchedulerBackend.setHostToLocalTaskCount(hostToLocalCount)
      when(sched.nodeBlacklist()).thenReturn(blacklist)
      val req = yarnSchedulerBackend.prepareRequestExecutors(numRequested)
      assert(req.requestedTotal === numRequested)
      assert(req.nodeBlacklist === blacklist)
      assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty)
      // Serialize to make sure serialization doesn't throw an error
      ser.serialize(req)
    }
    sc.stop()
  }

} 
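Note how when(sched.nodeBlacklist()) is re-stubbed on every pass of the for comprehension: Mockito lets later stubbings replace earlier ones, which keeps the loop body simple. A tiny sketch of that behavior, with a hypothetical Sched trait:

import org.mockito.Mockito.{mock, when}

trait Sched { def nodeBlacklist(): Set[String] }

object RestubSketch extends App {
  val sched = mock(classOf[Sched])
  for (bl <- Seq(Set.empty[String], Set("a", "b"))) {
    when(sched.nodeBlacklist()).thenReturn(bl) // the latest stubbing wins
    assert(sched.nodeBlacklist() == bl)
  }
}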
Example 27
Source File: MasterWebUISuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
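Both tests above assert on side effects with verify(mock, times(1)) instead of stubbing a return value: the mock records every interaction, and verification replays it afterwards. A minimal sketch with an illustrative MasterLike trait:

import org.mockito.Mockito.{mock, times, verify}

trait MasterLike { def removeApplication(id: String): Unit }

object VerifySketch extends App {
  val master = mock(classOf[MasterLike])
  master.removeApplication("app-0")
  verify(master, times(1)).removeApplication("app-0") // passes
  // verify(master, times(1)).removeApplication("app-1") would throw,
  // because no such interaction was recorded.
}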
Example 28
Source File: LogPageSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 29
Source File: RecordIOOutputFormatTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.protobuf

import java.io.ByteArrayOutputStream

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path}
import org.apache.hadoop.io.{BytesWritable, NullWritable}
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.mockito.Matchers.any
import org.mockito.Mockito.{verify, when}
import org.scalatest.{BeforeAndAfter, FlatSpec}
import org.scalatest.mock.MockitoSugar

import com.amazonaws.services.sagemaker.sparksdk.protobuf.RecordIOOutputFormat.SageMakerProtobufRecordWriter


class RecordIOOutputFormatTests extends FlatSpec with MockitoSugar with BeforeAndAfter {

  var sagemakerProtobufRecordWriter: SageMakerProtobufRecordWriter = _
  var mockOutputStream: FSDataOutputStream = _
  var byteArrayOutputStream: ByteArrayOutputStream = _
  var mockTaskAttemptContext: TaskAttemptContext = _
  var mockPath: Path = _
  var mockFileSystem: FileSystem = _

  before {
    byteArrayOutputStream = new ByteArrayOutputStream()
    mockOutputStream = mock[FSDataOutputStream]
    sagemakerProtobufRecordWriter = new SageMakerProtobufRecordWriter(mockOutputStream)
    mockTaskAttemptContext = mock[TaskAttemptContext]
    mockPath = mock[Path]
    mockFileSystem = mock[FileSystem]
  }

  it should "write an empty array of bytes" in {
    val bytesWritable = new BytesWritable(byteArrayOutputStream.toByteArray)

    val bytes = ProtobufConverter.byteArrayToRecordIOEncodedByteArray(bytesWritable.getBytes)
    sagemakerProtobufRecordWriter.write(NullWritable.get(), bytesWritable)

    verify(mockOutputStream).write(bytes, 0, bytes.length)
  }


  it should "write an array of bytes" in {
    val byteArray = Array[Byte](0, 0, 0, 0)
    byteArrayOutputStream.write(byteArray)
    val bytesWritable = new BytesWritable(byteArrayOutputStream.toByteArray)
    val bytes = ProtobufConverter.byteArrayToRecordIOEncodedByteArray(bytesWritable.getBytes)

    sagemakerProtobufRecordWriter.write(NullWritable.get(), bytesWritable)

    verify(mockOutputStream).write(bytes, 0, bytes.length)
  }

  it should "write an array of bytes, padding as necessary" in {
    byteArrayOutputStream.write(5)
    val bytesWritable = new BytesWritable(byteArrayOutputStream.toByteArray)
    val bytes = ProtobufConverter.byteArrayToRecordIOEncodedByteArray(bytesWritable.getBytes)

    sagemakerProtobufRecordWriter.write(NullWritable.get(), bytesWritable)

    verify(mockOutputStream).write(bytes, 0, bytes.length)
  }

  it should "write an array of bytes, padding only as much as necessary" in {
    byteArrayOutputStream.write(Array[Byte](0, 0, 0, 0, 0))
    val bytesWritable = new BytesWritable(byteArrayOutputStream.toByteArray)
    val bytes = ProtobufConverter.byteArrayToRecordIOEncodedByteArray(bytesWritable.getBytes)

    sagemakerProtobufRecordWriter.write(NullWritable.get(), bytesWritable)

    verify(mockOutputStream).write(bytes, 0, bytes.length)
  }

  it should "create a record writer from a FSDataOutputStream created by the filesystem" in {
    val mockTaskAttemptContext = mock[TaskAttemptContext]
    val mockPath = mock[Path]
    val mockFileSystem = mock[FileSystem]
    when(mockPath.getFileSystem(any[Configuration])).thenReturn(mockFileSystem)
    new RecordIOOutputFormat() {
      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
        mockPath
      }
    }.getRecordWriter(mockTaskAttemptContext)
    verify(mockFileSystem).create(mockPath, true)

  }

} 
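The last test above stubs getFileSystem with any[Configuration]. One Mockito rule to keep in mind: once a single argument uses a matcher, every argument must be a matcher, so literals get wrapped with eq (usually aliased to eqTo in Scala to avoid clashing with AnyRef.eq). A brief sketch over a hypothetical Store trait:

import org.mockito.Matchers.{any, eq => eqTo} // Mockito 1.x, as in the suite above
import org.mockito.Mockito.{mock, when}

trait Store { def read(path: String, retries: Int): String }

object MatcherSketch extends App {
  val store = mock(classOf[Store])
  // Mixing a raw literal with any(...) would throw
  // InvalidUseOfMatchersException; wrap the literal with eqTo instead.
  when(store.read(eqTo("/data"), any[Int])).thenReturn("ok")
  assert(store.read("/data", 7) == "ok")
}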
Example 30
Source File: MasterWebUISuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import java.util.Date

import scala.io.Source
import scala.language.postfixOps

import org.json4s.jackson.JsonMethods._
import org.json4s.JsonAST.{JNothing, JString, JInt}
import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.MasterStateResponse
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.RpcEnv


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter {

  val masterPage = mock(classOf[MasterPage])
  val master = {
    val conf = new SparkConf
    val securityMgr = new SecurityManager(conf)
    val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr)
    val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf)
    master
  }
  val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage))

  before {
    masterWebUI.bind()
  }

  after {
    masterWebUI.stop()
  }

  test("list applications") {
    val worker = createWorkerInfo()
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue)
    activeApp.addExecutor(worker, 2)

    val workers = Array[WorkerInfo](worker)
    val activeApps = Array(activeApp)
    val completedApps = Array[ApplicationInfo]()
    val activeDrivers = Array[DriverInfo]()
    val completedDrivers = Array[DriverInfo]()
    val stateResponse = new MasterStateResponse(
      "host", 8080, None, workers, activeApps, completedApps,
      activeDrivers, completedDrivers, RecoveryState.ALIVE)

    when(masterPage.getMasterState).thenReturn(stateResponse)

    val resultJson = Source.fromURL(
      s"http://localhost:${masterWebUI.boundPort}/api/v1/applications")
      .mkString
    val parsedJson = parse(resultJson)
    val firstApp = parsedJson(0)

    assert(firstApp \ "id" === JString(activeApp.id))
    assert(firstApp \ "name" === JString(activeApp.desc.name))
    assert(firstApp \ "coresGranted" === JInt(2))
    assert(firstApp \ "maxCores" === JInt(4))
    assert(firstApp \ "memoryPerExecutorMB" === JInt(1234))
    assert(firstApp \ "coresPerExecutor" === JNothing)
  }

} 
Example 31
Source File: LogPageSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 32
Source File: StagePageSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
    // Disable unsafe and make sure it's not there
    val conf2 = new SparkConf(false).set(unsafeConf, "false")
    val html2 = renderStagePage(conf2).toString().toLowerCase
    assert(!html2.contains(targetString))
    // Avoid setting anything; it should be displayed by default
    val conf3 = new SparkConf(false)
    val html3 = renderStagePage(conf3).toString().toLowerCase
    assert(html3.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains("<td>10.0 b</td>" * 5))
  }

  
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        val peakExecutionMemory = 10
        taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY,
          Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markSuccessful()
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
Example 33
Source File: DiskBlockManagerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
Example 34
Source File: PageServiceTest.scala    From theGardener   with Apache License 2.0 5 votes vote down vote up
package services


import controllers.dto.{PageFragment, PageFragmentContent}
import models._
import org.mockito.Mockito.when
import org.mockito._
import org.scalatest._
import org.scalatest.concurrent._
import org.scalatestplus.mockito._
import repositories._
import play.api.Configuration
import play.api.cache.SyncCacheApi
import services.clients.OpenApiClient

import scala.concurrent.ExecutionContext

class PageServiceTest extends WordSpec with MustMatchers with BeforeAndAfter with MockitoSugar with ScalaFutures {


  val projectRepository = mock[ProjectRepository]
  val directoryRepository = mock[DirectoryRepository]
  val pageRepository = mock[PageRepository]
  val featureService = mock[FeatureService]
  val cache = new PageServiceCache(mock[SyncCacheApi])
  val gherkinRepository = mock[GherkinRepository]
  val config = mock[Configuration]
  val openApiClient = mock[OpenApiClient]
  implicit val ec = mock[ExecutionContext]


  when(config.getOptional[String]("application.baseUrl")).thenReturn(None)

  val pageService = new PageService(config, projectRepository, directoryRepository, pageRepository, gherkinRepository, openApiClient, cache)

  val variables = Seq(Variable(s"$${name1}", "value"),Variable(s"$${name2}", "value2"))
  val contentWithMarkdown = Seq(PageFragment("markdown", PageFragmentContent(Some(s"$${name1}"))))
  val contentWithExternalPage = Seq(PageFragment("includeExternalPage", PageFragmentContent(None,None,Some(s"$${name1}"))))
  val contentWithTwoFragment = contentWithMarkdown ++ contentWithExternalPage

  before {
    Mockito.reset(pageRepository)
  }

  "PageService" should {
    "Replace Variable in Markdown" in {
      pageService.replaceVariablesInMarkdown(contentWithMarkdown, variables) must contain theSameElementsAs Seq(PageFragment("markdown", PageFragmentContent(Some(s"value"))))
    }

    "Replace Variables in Markdown with external Link" in {
      pageService.replaceVariablesInMarkdown(contentWithTwoFragment, variables) must contain theSameElementsAs Seq(PageFragment("markdown", PageFragmentContent(Some(s"value"))),PageFragment("includeExternalPage", PageFragmentContent(None,None,Some(s"value"))))
    }
  }




} 
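Because the mocks above are created once per suite, the before block calls Mockito.reset(pageRepository) to clear stubbings and recorded interactions between tests. A short sketch of what reset does, using a hypothetical Repo trait:

import org.mockito.Mockito
import org.mockito.Mockito.{mock, when}

trait Repo { def load(id: Int): String }

object ResetSketch extends App {
  val repo = mock(classOf[Repo])
  when(repo.load(1)).thenReturn("cached")
  assert(repo.load(1) == "cached")
  Mockito.reset(repo) // drops all stubbings and interaction history
  assert(repo.load(1) == null) // back to the default (null) answer
}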
Example 35
Source File: RouterTestHelpers.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.router2

import org.coursera.naptime.resources.CollectionResource
import org.mockito.Mockito.when

trait RouterTestHelpers {

  def setupParentMockCalls[T <: CollectionResource[_, _, _]](mock: T, instanceImpl: T): Unit = {
    when(mock.resourceName).thenReturn(instanceImpl.resourceName)
    when(mock.resourceVersion).thenReturn(instanceImpl.resourceVersion)
    when(mock.pathParser).thenReturn(instanceImpl.pathParser.asInstanceOf[mock.PathParser])
  }

  def assertRouted(resource: CollectionResource[_, _, _], urlFragment: String): Unit = {
    assert(
      !resource.optParse(urlFragment).isEmpty,
      s"Resource: ${resource.getClass.getName} did not accept url fragment: $urlFragment")
  }

  def assertNotRouted(resource: CollectionResource[_, _, _], urlFragment: String): Unit = {
    assert(
      resource.optParse(urlFragment).isEmpty,
      s"Resource: ${resource.getClass.getName} did accept url fragment: $urlFragment")
  }
} 
Example 36
Source File: PaymentsControllerSpec.scala    From pertax-frontend   with Apache License 2.0 5 votes vote down vote up
package controllers

import config.ConfigDecorator
import connectors._
import controllers.auth.requests.UserRequest
import controllers.auth.{AuthJourney, WithBreadcrumbAction}
import models.CreatePayment
import org.joda.time.DateTime
import org.mockito.Matchers.any
import org.mockito.Mockito.when
import org.scalatestplus.mockito.MockitoSugar
import play.api.Application
import play.api.i18n.MessagesApi
import play.api.inject.bind
import play.api.mvc.{ActionBuilder, MessagesControllerComponents, Request, Result}
import play.api.test.FakeRequest
import play.api.test.Helpers.{redirectLocation, _}
import uk.gov.hmrc.renderer.TemplateRenderer
import uk.gov.hmrc.time.CurrentTaxYear
import util.UserRequestFixture.buildUserRequest
import util.{ActionBuilderFixture, BaseSpec}

import scala.concurrent.{ExecutionContext, Future}

class PaymentsControllerSpec extends BaseSpec with CurrentTaxYear with MockitoSugar {

  override def now: () => DateTime = DateTime.now

  lazy val fakeRequest = FakeRequest("", "")

  val mockPayConnector = mock[PayApiConnector]
  val mockAuthJourney = mock[AuthJourney]

  override implicit lazy val app: Application = localGuiceApplicationBuilder()
    .overrides(
      bind[PayApiConnector].toInstance(mockPayConnector),
      bind[AuthJourney].toInstance(mockAuthJourney)
    )
    .build()

  def controller =
    new PaymentsController(
      mockPayConnector,
      mockAuthJourney,
      injected[WithBreadcrumbAction],
      injected[MessagesControllerComponents]
    )(mockLocalPartialRetriever, injected[ConfigDecorator], mock[TemplateRenderer], injected[ExecutionContext])

  when(mockAuthJourney.authWithPersonalDetails).thenReturn(new ActionBuilderFixture {
    override def invokeBlock[A](request: Request[A], block: UserRequest[A] => Future[Result]): Future[Result] =
      block(
        buildUserRequest(
          request = request
        ))
  })

  "makePayment" should {
    "redirect to the response's nextUrl" in {

      val expectedNextUrl = "someNextUrl"
      val createPaymentResponse = CreatePayment("someJourneyId", expectedNextUrl)

      when(mockPayConnector.createPayment(any())(any(), any()))
        .thenReturn(Future.successful(Some(createPaymentResponse)))

      val result = controller.makePayment()(FakeRequest())
      status(result) shouldBe SEE_OTHER
      redirectLocation(result) shouldBe Some("someNextUrl")
    }

    "redirect to a BAD_REQUEST page if createPayment failed" in {

      when(mockPayConnector.createPayment(any())(any(), any()))
        .thenReturn(Future.successful(None))

      val result = controller.makePayment()(FakeRequest())
      status(result) shouldBe BAD_REQUEST
    }
  }
} 
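The connector above is asynchronous, so its stubs return already-completed futures via Future.successful(...); the test stays deterministic without real I/O or a running pay-api. A compact sketch of the pattern, with a hypothetical Connector trait:

import org.mockito.Matchers.any
import org.mockito.Mockito.{mock, when}
import scala.concurrent.Future

trait Connector { def createPayment(req: String): Future[Option[String]] }

object FutureStubSketch extends App {
  val connector = mock(classOf[Connector])
  // An already-completed future makes the stubbed call effectively synchronous
  when(connector.createPayment(any[String]))
    .thenReturn(Future.successful(Some("someNextUrl")))
}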
Example 37
Source File: EnrolmentsConnectorSpec.scala    From pertax-frontend   with Apache License 2.0 5 votes vote down vote up
package connectors

import models._
import org.joda.time.DateTime
import org.mockito.Matchers.{any, eq => eqTo}
import org.mockito.Mockito.when
import org.scalatest.EitherValues
import org.scalatest.Inspectors.forAll
import org.scalatest.concurrent.ScalaFutures
import org.scalatestplus.mockito.MockitoSugar
import play.api.http.Status._
import play.api.libs.json.{JsObject, JsResultException, Json}
import uk.gov.hmrc.http.{HttpException, HttpResponse}
import uk.gov.hmrc.play.bootstrap.http.DefaultHttpClient
import util.BaseSpec

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future

class EnrolmentsConnectorSpec extends BaseSpec with MockitoSugar with ScalaFutures with EitherValues {

  val http = mock[DefaultHttpClient]
  val connector = new EnrolmentsConnector(http, config)
  val baseUrl = config.enrolmentStoreProxyUrl

  "getAssignedEnrolments" should {
    val utr = "1234500000"
    val url = s"$baseUrl/enrolment-store/enrolments/IR-SA~UTR~$utr/users"

    "Return the error message for a BAD_REQUEST response" in {
      when(http.GET[HttpResponse](eqTo(url))(any(), any(), any()))
        .thenReturn(Future.successful(HttpResponse(BAD_REQUEST)))

      connector.getUserIdsWithEnrolments(utr).futureValue.left.value should include(BAD_REQUEST.toString)
    }

    "NO_CONTENT response should return no enrolments" in {
      when(http.GET[HttpResponse](eqTo(url))(any(), any(), any()))
        .thenReturn(Future.successful(HttpResponse(NO_CONTENT)))

      connector.getUserIdsWithEnrolments(utr).futureValue.right.value shouldBe Seq.empty
    }

    "query users with no principal enrolment returns empty enrolments" in {
      val json = Json.parse("""
                              |{
                              |    "principalUserIds": [],
                              |     "delegatedUserIds": []
                              |}""".stripMargin)

      when(http.GET[HttpResponse](eqTo(url))(any(), any(), any()))
        .thenReturn(Future.successful(HttpResponse(OK, Some(json))))

      connector.getUserIdsWithEnrolments(utr).futureValue.right.value shouldBe Seq.empty
    }

    "query users with assigned enrolment return two principleIds" in {
      val json = Json.parse("""
                              |{
                              |    "principalUserIds": [
                              |       "ABCEDEFGI1234567",
                              |       "ABCEDEFGI1234568"
                              |    ],
                              |    "delegatedUserIds": [
                              |     "dont care"
                              |    ]
                              |}""".stripMargin)

      when(http.GET[HttpResponse](eqTo(url))(any(), any(), any()))
        .thenReturn(Future.successful(HttpResponse(OK, Some(json))))

      val expected = Seq("ABCEDEFGI1234567", "ABCEDEFGI1234568")

      connector.getUserIdsWithEnrolments(utr).futureValue.right.value shouldBe expected
    }
  }
} 
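The stubbings above pass (any(), any(), any()) in a second argument list because http.GET takes implicit parameters; to Mockito these are ordinary arguments, and each one needs its own matcher. A reduced sketch with a hypothetical Http trait carrying one implicit:

import org.mockito.Matchers.{any, eq => eqTo}
import org.mockito.Mockito.{mock, when}
import scala.concurrent.{ExecutionContext, Future}

trait Http {
  def GET[A](url: String)(implicit ec: ExecutionContext): Future[A]
}

object ImplicitStubSketch extends App {
  val http = mock(classOf[Http])
  // Supply a matcher for the implicit parameter explicitly when stubbing
  when(http.GET[String](eqTo("/ping"))(any()))
    .thenReturn(Future.successful("pong"))
}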
Example 38
Source File: PayApiConnectorSpec.scala    From pertax-frontend   with Apache License 2.0 5 votes vote down vote up
package connectors

import models.{CreatePayment, PaymentRequest}
import org.mockito.Matchers.{any, eq => eqTo}
import org.mockito.Mockito.when
import org.scalatest.concurrent.ScalaFutures
import org.scalatestplus.mockito.MockitoSugar
import play.api.http.Status._
import play.api.libs.json.{JsResultException, Json}
import uk.gov.hmrc.http.HttpResponse
import uk.gov.hmrc.play.bootstrap.http.DefaultHttpClient
import util.BaseSpec

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future

class PayApiConnectorSpec extends BaseSpec with MockitoSugar with ScalaFutures {

  val http = mock[DefaultHttpClient]
  val connector = new PayApiConnector(http, config)
  val paymentRequest = PaymentRequest(config, "some utr")
  val postUrl = config.makeAPaymentUrl

  "createPayment" should {
    "parse the json load for a successful CREATED response" in {
      val json = Json.obj(
        "journeyId" -> "exampleJourneyId",
        "nextUrl"   -> "testNextUrl"
      )

      when(
        http.POST[PaymentRequest, HttpResponse](eqTo(postUrl), eqTo(paymentRequest), any())(any(), any(), any(), any()))
        .thenReturn(Future.successful(HttpResponse(CREATED, Some(json))))

      connector.createPayment(paymentRequest).futureValue shouldBe Some(
        CreatePayment("exampleJourneyId", "testNextUrl"))
    }

    "Returns a None when the status code is not CREATED" in {
      when(
        http.POST[PaymentRequest, HttpResponse](eqTo(postUrl), eqTo(paymentRequest), any())(any(), any(), any(), any()))
        .thenReturn(Future.successful(HttpResponse(BAD_REQUEST)))

      connector.createPayment(paymentRequest).futureValue shouldBe None
    }

    "Throws a JsResultException when given bad json" in {
      val badJson = Json.obj("abc" -> "invalidData")

      when(
        http.POST[PaymentRequest, HttpResponse](eqTo(postUrl), eqTo(paymentRequest), any())(any(), any(), any(), any()))
        .thenReturn(Future.successful(HttpResponse(CREATED, Some(badJson))))

      val f = connector.createPayment(paymentRequest)
      whenReady(f.failed) { e =>
        e shouldBe a[JsResultException]
      }
    }
  }
} 
Example 39
Source File: InMemoryLedgerReaderWriterSpec.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.memory

import com.codahale.metrics.MetricRegistry
import com.daml.ledger.api.testing.utils.AkkaBeforeAndAfterAll
import com.daml.ledger.participant.state.kvutils.api.CommitMetadata
import com.daml.ledger.participant.state.v1.{ParticipantId, SubmissionResult}
import com.daml.ledger.validator.{BatchedValidatingCommitter, LedgerStateOperations}
import com.daml.lf.data.Ref
import com.daml.metrics.Metrics
import com.daml.platform.akkastreams.dispatcher.Dispatcher
import com.google.protobuf.ByteString
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito.{times, verify, when}
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{AsyncWordSpec, Matchers}

import scala.concurrent.{ExecutionContext, Future}

class InMemoryLedgerReaderWriterSpec
    extends AsyncWordSpec
    with AkkaBeforeAndAfterAll
    with Matchers
    with MockitoSugar {
  "commit" should {
    "not signal new head in case of failure" in {
      val mockDispatcher = mock[Dispatcher[Index]]
      val mockCommitter = mock[BatchedValidatingCommitter[Index]]
      when(
        mockCommitter.commit(
          anyString(),
          any[ByteString](),
          any[ParticipantId](),
          any[LedgerStateOperations[Index]])(any[ExecutionContext]()))
        .thenReturn(
          Future.successful(SubmissionResult.InternalError("Validation failed with an exception")))
      val instance = new InMemoryLedgerReaderWriter(
        Ref.ParticipantId.assertFromString("participant ID"),
        "ledger ID",
        mockDispatcher,
        InMemoryState.empty,
        mockCommitter,
        new Metrics(new MetricRegistry)
      )

      instance
        .commit("correlation ID", ByteString.copyFromUtf8("some bytes"), CommitMetadata.Empty)
        .map { actual =>
          verify(mockDispatcher, times(0)).signalNewHead(anyInt())
          actual should be(a[SubmissionResult.InternalError])
        }
    }
  }
} 
Example 40
Source File: StoreBackedCommandExecutorSpec.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.platform.apiserver.execution

import com.daml.ledger.api.domain.Commands
import com.daml.ledger.participant.state.index.v2.{ContractStore, IndexPackagesService}
import com.daml.lf.crypto.Hash
import com.daml.lf.data.Ref.ParticipantId
import com.daml.lf.data.{ImmArray, Ref, Time}
import com.daml.lf.engine.{Engine, ResultDone}
import com.daml.lf.transaction.Transaction
import com.daml.lf.transaction.test.TransactionBuilder
import com.daml.logging.LoggingContext
import com.daml.metrics.Metrics
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito.when
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{AsyncWordSpec, Matchers}

class StoreBackedCommandExecutorSpec extends AsyncWordSpec with MockitoSugar with Matchers {
  private val emptyTransaction =
    Transaction.SubmittedTransaction(TransactionBuilder.Empty)
  private val emptyTransactionMetadata = Transaction.Metadata(
    submissionSeed = None,
    submissionTime = Time.Timestamp.now(),
    usedPackages = Set.empty,
    dependsOnTime = false,
    nodeSeeds = ImmArray.empty,
    byKeyNodes = ImmArray.empty)

  "execute" should {
    "add interpretation time to result" in {
      val mockEngine = mock[Engine]
      when(mockEngine.submit(any[com.daml.lf.command.Commands], any[ParticipantId], any[Hash]))
        .thenReturn(
          ResultDone[(Transaction.SubmittedTransaction, Transaction.Metadata)](
            (emptyTransaction, emptyTransactionMetadata)
          )
        )
      val instance = new StoreBackedCommandExecutor(
        mockEngine,
        Ref.ParticipantId.assertFromString("anId"),
        mock[IndexPackagesService],
        mock[ContractStore],
        mock[Metrics])
      val mockDomainCommands = mock[Commands]
      val mockLfCommands = mock[com.daml.lf.command.Commands]
      when(mockLfCommands.ledgerEffectiveTime).thenReturn(Time.Timestamp.now())
      when(mockDomainCommands.workflowId).thenReturn(None)
      when(mockDomainCommands.commands).thenReturn(mockLfCommands)

      LoggingContext.newLoggingContext { implicit context =>
        instance.execute(mockDomainCommands, Hash.hashPrivateKey("a key")).map { actual =>
          actual.right.foreach { actualResult =>
            actualResult.interpretationTimeNanos should be > 0L
          }
          succeed
        }
      }
    }
  }
} 
Example 41
Source File: BatchedValidatingCommitterSpec.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.validator

import java.time.Instant

import akka.stream.Materializer
import com.daml.ledger.api.testing.utils.AkkaBeforeAndAfterAll
import com.daml.ledger.participant.state.v1.{ParticipantId, SubmissionResult}
import com.daml.ledger.validator.TestHelper.aParticipantId
import com.daml.ledger.validator.batch.BatchedSubmissionValidator
import com.google.protobuf.ByteString
import org.mockito.ArgumentMatchers.{any, anyString}
import org.mockito.Mockito.when
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{AsyncWordSpec, Matchers}

import scala.concurrent.{ExecutionContext, Future}

class BatchedValidatingCommitterSpec
    extends AsyncWordSpec
    with AkkaBeforeAndAfterAll
    with Matchers
    with MockitoSugar {
  "commit" should {
    "return Acknowledged in case of success" in {
      val mockValidator = mock[BatchedSubmissionValidator[Unit]]
      when(
        mockValidator.validateAndCommit(
          any[ByteString](),
          anyString(),
          any[Instant](),
          any[ParticipantId](),
          any[DamlLedgerStateReader](),
          any[CommitStrategy[Unit]]())(any[Materializer](), any[ExecutionContext]()))
        .thenReturn(Future.unit)
      val instance =
        BatchedValidatingCommitter[Unit](() => Instant.now(), mockValidator)

      instance
        .commit("", ByteString.EMPTY, aParticipantId, mock[LedgerStateOperations[Unit]])
        .map { actual =>
          actual shouldBe SubmissionResult.Acknowledged
        }
    }

    "return InternalError in case of an exception" in {
      val mockValidator = mock[BatchedSubmissionValidator[Unit]]
      when(
        mockValidator.validateAndCommit(
          any[ByteString](),
          anyString(),
          any[Instant](),
          any[ParticipantId](),
          any[DamlLedgerStateReader](),
          any[CommitStrategy[Unit]]())(any[Materializer](), any[ExecutionContext]()))
        .thenReturn(Future.failed(new IllegalArgumentException("Validation failure")))
      val instance = BatchedValidatingCommitter[Unit](() => Instant.now(), mockValidator)

      instance
        .commit("", ByteString.EMPTY, aParticipantId, mock[LedgerStateOperations[Unit]])
        .map { actual =>
          actual shouldBe SubmissionResult.InternalError("Validation failure")
        }
    }
  }
} 
Example 42
Source File: LogAppendingCommitStrategySpec.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.validator

import org.scalatest.mockito.MockitoSugar
import org.scalatest.{AsyncWordSpec, Matchers}
import TestHelper._
import com.daml.ledger.participant.state.kvutils.DamlKvutils.{DamlStateKey, DamlStateValue}
import com.daml.ledger.participant.state.kvutils.Envelope
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import com.google.protobuf.ByteString
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito.{times, verify, when}

import scala.concurrent.Future

class LogAppendingCommitStrategySpec extends AsyncWordSpec with Matchers with MockitoSugar {
  "commit" should {
    "return index from appendToLog" in {
      val mockLedgerStateOperations = mock[LedgerStateOperations[Long]]
      val expectedIndex = 1234L
      when(mockLedgerStateOperations.appendToLog(any[Key](), any[Value]()))
        .thenReturn(Future.successful(expectedIndex))
      val instance =
        new LogAppendingCommitStrategy[Long](
          mockLedgerStateOperations,
          DefaultStateKeySerializationStrategy)

      instance
        .commit(aParticipantId, "a correlation ID", aLogEntryId(), aLogEntry, Map.empty, Map.empty)
        .map { actualIndex =>
          verify(mockLedgerStateOperations, times(1)).appendToLog(any[Key](), any[Value]())
          verify(mockLedgerStateOperations, times(0)).writeState(any[Seq[(Key, Value)]]())
          actualIndex should be(expectedIndex)
        }
    }

    "write keys serialized according to strategy" in {
      val mockLedgerStateOperations = mock[LedgerStateOperations[Long]]
      val actualOutputStateBytesCaptor = ArgumentCaptor
        .forClass(classOf[Seq[(Key, Value)]])
        .asInstanceOf[ArgumentCaptor[Seq[(Key, Value)]]]
      when(mockLedgerStateOperations.writeState(actualOutputStateBytesCaptor.capture()))
        .thenReturn(Future.unit)
      when(mockLedgerStateOperations.appendToLog(any[Key](), any[Value]()))
        .thenReturn(Future.successful(0L))
      val mockStateKeySerializationStrategy = mock[StateKeySerializationStrategy]
      val expectedStateKey = ByteString.copyFromUtf8("some key")
      when(mockStateKeySerializationStrategy.serializeStateKey(any[DamlStateKey]()))
        .thenReturn(expectedStateKey)
      val expectedOutputStateBytes = Seq((expectedStateKey, Envelope.enclose(aStateValue)))
      val instance =
        new LogAppendingCommitStrategy[Long](
          mockLedgerStateOperations,
          mockStateKeySerializationStrategy)

      instance
        .commit(
          aParticipantId,
          "a correlation ID",
          aLogEntryId(),
          aLogEntry,
          Map.empty,
          Map(aStateKey -> aStateValue))
        .map { _: Long =>
          verify(mockStateKeySerializationStrategy, times(1)).serializeStateKey(aStateKey)
          verify(mockLedgerStateOperations, times(1)).writeState(any[Seq[(Key, Value)]]())
          actualOutputStateBytesCaptor.getValue should be(expectedOutputStateBytes)
        }
    }
  }

  private val aStateKey: DamlStateKey = DamlStateKey
    .newBuilder()
    .setContractId(1.toString)
    .build

  private val aStateValue: DamlStateValue = DamlStateValue.getDefaultInstance
} 
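The second test above wires an ArgumentCaptor into when(...) so the written bytes can be inspected later; the more common placement is at verification time, which the sketch below uses. The Log trait and append method are illustrative only:

import org.mockito.ArgumentCaptor
import org.mockito.Mockito.{mock, verify}

trait Log { def append(entry: String): Long }

object CaptorSketch extends App {
  val log = mock(classOf[Log])
  log.append("hello")
  val captor = ArgumentCaptor.forClass(classOf[String])
  verify(log).append(captor.capture()) // records the actual argument
  assert(captor.getValue == "hello")
}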
Example 43
Source File: MasterWebUISuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
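
A recurring idiom in this suite: every accessor the MasterWebUI reads is stubbed with when(...).thenReturn(...) before the UI object is constructed, so construction-time reads already see the stubbed values. A minimal sketch of that ordering, using an illustrative MasterLike trait rather than Spark's real Master:

import org.mockito.Mockito.{mock, when}

object StubBeforeConstructionSketch {
  trait MasterLike {
    def appName: String
    def port: Int
  }

  // A unit under test that reads its collaborator at construction time.
  class WebUI(master: MasterLike) {
    val title: String = s"${master.appName}:${master.port}"
  }

  def main(args: Array[String]): Unit = {
    val master = mock(classOf[MasterLike])
    when(master.appName).thenReturn("test-app") // stub before constructing WebUI
    when(master.port).thenReturn(8080)
    assert(new WebUI(master).title == "test-app:8080")
  }
}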
Example 44
Source File: LogPageSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  // Write the given string to the file.
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 45
Source File: StagePageSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory only displayed if unsafe is enabled") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
    // Disable unsafe and make sure it's not there
    val conf2 = new SparkConf(false).set(unsafeConf, "false")
    val html2 = renderStagePage(conf2).toString().toLowerCase
    assert(!html2.contains(targetString))
    // Avoid setting anything; it should be displayed by default
    val conf3 = new SparkConf(false)
    val html3 = renderStagePage(conf3).toString().toLowerCase
    assert(html3.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val unsafeConf = "spark.sql.unsafe.enabled"
    val conf = new SparkConf(false).set(unsafeConf, "true")
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
  }

  
  // Render a StagePage with the given SparkConf and return the resulting HTML nodes.
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.executorsListener).thenReturn(executorsListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markFinished(TaskState.FINISHED)
        val taskMetrics = TaskMetrics.empty
        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
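
Note the second argument to mock(classOf[StagesTab], RETURNS_SMART_NULLS): it installs a default answer under which any method left unstubbed returns a "smart null" that fails with a descriptive SmartNullPointerException when first used, instead of a bare NullPointerException far from the cause. A minimal sketch with illustrative Conf and Tab traits:

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

object SmartNullsSketch {
  trait Conf { def get(key: String): String }
  trait Tab {
    def appName: String
    def conf: Conf
  }

  def main(args: Array[String]): Unit = {
    val tab = mock(classOf[Tab], RETURNS_SMART_NULLS)
    when(tab.appName).thenReturn("testing") // only appName is stubbed
    assert(tab.appName == "testing")
    // tab.conf was never stubbed: it returns a smart-null proxy that throws
    // a SmartNullPointerException with a helpful message if it is ever used.
  }
}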
Example 46
Source File: AuthActionSelectorSpec.scala    From nisp-frontend   with Apache License 2.0 5 votes vote down vote up
package uk.gov.hmrc.nisp.controllers.auth

import org.mockito.Mockito.when
import org.scalatest.mockito.MockitoSugar
import org.scalatestplus.play.OneAppPerSuite
import uk.gov.hmrc.nisp.config.ApplicationConfig
import uk.gov.hmrc.play.test.UnitSpec

class AuthActionSelectorSpec extends UnitSpec with OneAppPerSuite with MockitoSugar {

  "The decide method" should {

    "use VerifyAuthActionImpl when IV disabled" in {
      val applicationConfig: ApplicationConfig = mock[ApplicationConfig]

      when(applicationConfig.identityVerification).thenReturn(false)
      AuthActionSelector.decide(applicationConfig) shouldBe a[VerifyAuthActionImpl]
    }

    "use AuthActionImpl when IV enabled" in {
      val applicationConfig: ApplicationConfig = mock[ApplicationConfig]

      when(applicationConfig.identityVerification).thenReturn(true)
      AuthActionSelector.decide(applicationConfig) shouldBe an[AuthActionImpl]
    }
  }
} 
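
Each test here reduces to stubbing a single Boolean on a mocked config and asserting which implementation the selector picks. The same shape in isolation, with an illustrative FeatureConfig trait:

import org.mockito.Mockito.{mock, when}

object BooleanFlagSketch {
  trait FeatureConfig { def identityVerification: Boolean }

  def main(args: Array[String]): Unit = {
    val config = mock(classOf[FeatureConfig])
    when(config.identityVerification).thenReturn(true)
    // The code under test branches on the stubbed flag.
    val chosen = if (config.identityVerification) "AuthActionImpl" else "VerifyAuthActionImpl"
    assert(chosen == "AuthActionImpl")
  }
}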
Example 47
Source File: NispFrontendControllerSpec.scala    From nisp-frontend   with Apache License 2.0 5 votes vote down vote up
package uk.gov.hmrc.nisp.controllers

import org.slf4j.{Logger => Slf4JLogger}
import org.mockito.Mockito.{verify, when}
import org.mockito.ArgumentMatchers._
import org.scalatest.mock.MockitoSugar
import org.scalatestplus.play.OneAppPerSuite
import play.api.Logger
import play.api.http.Status
import play.api.mvc.Result
import play.api.test.FakeRequest
import uk.gov.hmrc.nisp.config.{ApplicationGlobal, ApplicationGlobalTrait}
import uk.gov.hmrc.nisp.helpers.{MockApplicationGlobal, MockCachedStaticHtmlPartialRetriever}
import uk.gov.hmrc.nisp.utils.MockTemplateRenderer
import uk.gov.hmrc.play.test.UnitSpec
import uk.gov.hmrc.renderer.TemplateRenderer

class NispFrontendControllerSpec extends UnitSpec with MockitoSugar with OneAppPerSuite {

  val mockLogger: Slf4JLogger = mock[Slf4JLogger]
  when(mockLogger.isErrorEnabled).thenReturn(true)

  def controller = new NispFrontendController {
    override val logger = new Logger(mockLogger)
    val cachedStaticHtmlPartialRetriever = MockCachedStaticHtmlPartialRetriever
    override implicit val templateRenderer: TemplateRenderer = MockTemplateRenderer
    override val applicationGlobal: ApplicationGlobalTrait = MockApplicationGlobal
  }

  implicit val request = FakeRequest()

  "onError" should {
    "should log error details" in {
      val result: Result =  controller.onError(new Exception())
      verify(mockLogger).error(anyString(), any[Exception])
    }

    "should return an Internal Server Error (500)" in {
      val result: Result =  controller.onError(new Exception())
      status(result) shouldBe Status.INTERNAL_SERVER_ERROR
    }
  }

} 
Example 48
Source File: BackendConnectorSpec.scala    From nisp-frontend   with Apache License 2.0 5 votes vote down vote up
package uk.gov.hmrc.nisp.connectors

import org.mockito.Mockito.when
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.mock.MockitoSugar
import play.api.libs.json.Json
import uk.gov.hmrc.http.cache.client.SessionCache
import uk.gov.hmrc.nisp.helpers.{MockMetricsService, MockSessionCache}
import uk.gov.hmrc.nisp.models.NationalInsuranceRecord
import uk.gov.hmrc.nisp.models.enums.APIType
import uk.gov.hmrc.nisp.services.MetricsService
import uk.gov.hmrc.nisp.utils.JsonDepersonaliser
import uk.gov.hmrc.play.test.UnitSpec

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}
import uk.gov.hmrc.http.{HeaderCarrier, HttpGet, HttpResponse}

class BackendConnectorSpec extends UnitSpec with MockitoSugar with ScalaFutures {

  val mockHttp: HttpGet = mock[HttpGet]
  
  object BackendConnectorImpl extends BackendConnector {
    override def http: HttpGet = mockHttp
    override def sessionCache: SessionCache = MockSessionCache
    override def serviceUrl: String = "national-insurance"
    override val metricsService: MetricsService = MockMetricsService

    def getNationalInsurance()(implicit headerCarrier: HeaderCarrier): Future[NationalInsuranceRecord] = {
      val urlToRead = s"$serviceUrl/ni"
      retrieveFromCache[NationalInsuranceRecord](APIType.NationalInsurance, urlToRead)(headerCarrier, NationalInsuranceRecord.formats)
    }
  }

  implicit val headerCarrier = HeaderCarrier(extraHeaders = Seq("Accept" -> "application/vnd.hmrc.1.0+json"))

  "connectToMicroservice" should {
    "should return depersonalised JSON" in {
      val json = Json.obj(
        "qualifyingYearsPriorTo1975" -> 0,
        "numberOfGaps" -> 6,
        "numberOfGapsPayable" -> 4,
        "dateOfEntry" -> "1975-08-01",
        "homeResponsibilitiesProtection" -> false,
        "earningsIncludedUpTo" -> "2016-04-05",
        "_embedded" -> Json.obj(
          "taxYears" -> Json.arr()
        )
      )

      val depersonalisedJson = JsonDepersonaliser.depersonalise(json) match {
        case Success(s) => s
        case Failure(_) => fail()
      }

      val Ok = 200
      val response = Future(HttpResponse(Ok, Some(json)))
      when(mockHttp.GET[HttpResponse]("national-insurance/ni")).thenReturn(response)

      val future: Future[NationalInsuranceRecord] = BackendConnectorImpl.getNationalInsurance()

      whenReady(future.failed) {
        t: Throwable =>
          t.getMessage.contains(depersonalisedJson) shouldBe true
          t.getMessage.contains("2016-04-05") shouldBe false
      }
    }
  }

} 
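
Stubbing a Future-returning method works like any other when(...).thenReturn(...): hand back an already-constructed Future. Keep in mind the stub matches only the exact argument given; any other URL falls through to Mockito's default answer, which is null for object types. A sketch with an illustrative HttpGetLike trait:

import scala.concurrent.Future

import org.mockito.Mockito.{mock, when}

object FutureStubSketch {
  trait HttpGetLike { def get(url: String): Future[String] }

  def main(args: Array[String]): Unit = {
    val http = mock(classOf[HttpGetLike])
    when(http.get("national-insurance/ni"))
      .thenReturn(Future.successful("""{"numberOfGaps":6}"""))

    assert(http.get("national-insurance/ni") != null) // exact argument: stubbed
    assert(http.get("other/url") == null)             // different argument: default null
  }
}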
Example 49
Source File: JsonRequestSpec.scala    From play-ws   with Apache License 2.0 5 votes vote down vote up
package play.api.libs.ws.ahc

import java.nio.charset.StandardCharsets

import akka.actor.ActorSystem
import akka.stream.Materializer
import akka.util.ByteString
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
import org.mockito.Mockito.when
import org.specs2.mock.Mockito

import org.specs2.mutable.Specification
import org.specs2.specification.AfterAll
import play.api.libs.json.JsString
import play.api.libs.json.JsValue
import play.api.libs.json.Json
import play.api.libs.ws.JsonBodyReadables
import play.api.libs.ws.JsonBodyWritables
import play.libs.ws.DefaultObjectMapper
import play.shaded.ahc.org.asynchttpclient.Response

import scala.io.Codec


class JsonRequestSpec extends Specification with Mockito with AfterAll with JsonBodyWritables {
  sequential

  implicit val system       = ActorSystem()
  implicit val materializer = Materializer.matFromSystem

  override def afterAll: Unit = {
    system.terminate()
  }

  "set a json node" in {
    val jsValue = Json.obj("k1" -> JsString("v1"))
    val client  = mock[StandaloneAhcWSClient]
    val req = new StandaloneAhcWSRequest(client, "http://playframework.com/", null)
      .withBody(jsValue)
      .asInstanceOf[StandaloneAhcWSRequest]
      .buildRequest()

    req.getHeaders.get("Content-Type") must be_==("application/json")
    ByteString.fromArray(req.getByteData).utf8String must be_==("""{"k1":"v1"}""")
  }

  "set a json node using the default object mapper" in {
    val objectMapper = DefaultObjectMapper.instance

    implicit val jsonReadable = body(objectMapper)
    val jsonNode              = objectMapper.readTree("""{"k1":"v1"}""")
    val client                = mock[StandaloneAhcWSClient]
    val req = new StandaloneAhcWSRequest(client, "http://playframework.com/", null)
      .withBody(jsonNode)
      .asInstanceOf[StandaloneAhcWSRequest]
      .buildRequest()

    req.getHeaders.get("Content-Type") must be_==("application/json")
    ByteString.fromArray(req.getByteData).utf8String must be_==("""{"k1":"v1"}""")
  }

  "read an encoding of UTF-8" in {
    val json = io.Source.fromResource("test.json")(Codec.ISO8859).getLines.mkString

    val ahcResponse = mock[Response]
    val response    = new StandaloneAhcWSResponse(ahcResponse)

    when(ahcResponse.getResponseBody(StandardCharsets.UTF_8)).thenReturn(json)
    when(ahcResponse.getContentType).thenReturn("application/json")

    val value: JsValue = JsonBodyReadables.readableAsJson.transform(response)
    verify(ahcResponse, times(1)).getResponseBody(StandardCharsets.UTF_8)
    verify(ahcResponse, times(1)).getContentType
    value.toString must beEqualTo(json)
  }

  "read an encoding of ISO-8859-1" in {
    val json = io.Source.fromResource("test.json")(Codec.ISO8859).getLines.mkString

    val ahcResponse = mock[Response]
    val response    = new StandaloneAhcWSResponse(ahcResponse)

    when(ahcResponse.getResponseBody(StandardCharsets.ISO_8859_1)).thenReturn(json)
    when(ahcResponse.getContentType).thenReturn("application/json;charset=iso-8859-1")

    val value: JsValue = JsonBodyReadables.readableAsJson.transform(response)
    verify(ahcResponse, times(1)).getResponseBody(StandardCharsets.ISO_8859_1)
    verify(ahcResponse, times(1)).getContentType
    value.toString must beEqualTo(json)
  }
} 
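
The encoding tests pair each when(...) stub with a verify(..., times(1)) check, asserting both the value produced and that the reader actually hit the stubbed method exactly once. The bare stub-then-verify pattern, with an illustrative ResponseLike trait:

import org.mockito.Mockito.{mock, times, verify, when}

object StubThenVerifySketch {
  trait ResponseLike { def body(charset: String): String }

  def main(args: Array[String]): Unit = {
    val response = mock(classOf[ResponseLike])
    when(response.body("UTF-8")).thenReturn("""{"k1":"v1"}""")

    val read = response.body("UTF-8") // exercise the mock

    assert(read == """{"k1":"v1"}""")
    verify(response, times(1)).body("UTF-8") // assert the interaction happened once
  }
}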
Example 50
Source File: LaboratoryControllerSuccessfulSpec.scala    From Aton   with GNU General Public License v3.0 5 votes vote down vote up
package controllers.admin

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import org.mockito.Matchers.any
import org.mockito.Mockito.when
import com.google.inject.ImplementedBy
import com.google.inject.Inject
import jp.t2v.lab.play2.auth.test.Helpers.AuthFakeRequest
import model.Laboratory
import model.Role
import model.User
import model.form.LaboratoryForm
import model.form.data.LaboratoryFormData
import model.json.{LaboratoryJson, LoginJson}
import play.api.Environment
import play.api.i18n.MessagesApi
import play.api.libs.json.Json
import play.api.test.FakeRequest
import play.test.WithApplication
import services.LaboratoryService
import services.RoomService
import services.UserService
import services.impl.LaboratoryServiceImpl
import services.state
import services.state.ActionState
import test.ControllerTest


class LaboratoryControllerSuccessfulSpec extends LaboratoryControllerSpec {
  val labService: LaboratoryService = mockLaboratoryService(state.ActionCompleted)
  // Controller to be tested, with the dependencies
  lazy val controller = new LaboratoryController(labService, messagesApi)(userService, executionContext, environment)

  "Laboratory Controller on successful operations" should {
    "return Ok <200> status on receiving an edited laboratory" in {
      import laboratory._
      val laboratoryData = LaboratoryFormData(name, location, administration)
      val laboratoryForm = LaboratoryForm.form.fill(laboratoryData)
      val result = controller.update.apply {
        FakeRequest()
          .withLoggedIn(controller)(loggedInUser)
          .withJsonBody(Json.toJson(laboratory))
      }
      assertFutureResultStatus(result, 200)
    }

    "return Ok <200> status on deleting a laboratory" in {
      val result = controller.delete(laboratory.id).apply {
        FakeRequest()
          .withLoggedIn(controller)(LoginJson("admin", "adminaton"))
      }
      assertFutureResultStatus(result, 200)
    }

    "return Ok <200> status on adding a new laboratory" in {
      import laboratory._
      val laboratoryData = LaboratoryJson(name, location, administration)
      val result = controller.add.apply {
        FakeRequest()
          .withLoggedIn(controller)(loggedInUser)
          .withJsonBody(Json.toJson(laboratoryData))
      }
      assertFutureResultStatus(result, 200)
    }

    "return Ok <200> status when listing all laboratories" in pending
    "return laboratory list json when listing all laboratories" in pending
  }
} 
Example 51
Source File: SqlTimestampProviderTest.scala    From bandar-log   with Apache License 2.0 5 votes vote down vote up
package com.aol.one.dwh.bandarlog.providers

import com.aol.one.dwh.bandarlog.connectors.JdbcConnector
import com.aol.one.dwh.infra.config.Table
import com.aol.one.dwh.infra.sql.{Query, VerticaMaxValuesQuery}
import org.mockito.Matchers.any
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class SqlTimestampProviderTest extends FunSuite with MockitoSugar {

  private val query = VerticaMaxValuesQuery(Table("table", List("column"), None))
  private val jdbcConnector = mock[JdbcConnector]
  private val sqlTimestampProvider = new SqlTimestampProvider(jdbcConnector, query)

  test("check timestamp value by connector and query") {
    val resultTimestamp = Some(1234567890L)
    when(jdbcConnector.runQuery(any(classOf[Query]), any())).thenReturn(resultTimestamp)

    val result = sqlTimestampProvider.provide()

    assert(result.getValue == resultTimestamp)
  }

  test("return none if can't get timestamp value") {
    when(jdbcConnector.runQuery(any(classOf[Query]), any())).thenReturn(None)

    val result = sqlTimestampProvider.provide()

    assert(result.getValue.isEmpty)
  }
} 
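
The stub above does not care which Query instance arrives, so it matches arguments with any(...). One rule worth remembering: matchers are all-or-nothing; once one argument of a call uses a matcher, every argument must. A sketch with an illustrative Connector trait (org.mockito.Matchers became org.mockito.ArgumentMatchers in Mockito 2):

import org.mockito.Matchers.any
import org.mockito.Mockito.{mock, when}

object MatcherStubSketch {
  trait Connector { def runQuery(query: String, setting: String): Option[Long] }

  def main(args: Array[String]): Unit = {
    val connector = mock(classOf[Connector])
    // All arguments use matchers; mixing matchers and literals would throw.
    when(connector.runQuery(any(classOf[String]), any(classOf[String])))
      .thenReturn(Some(1234567890L))

    assert(connector.runQuery("q", "s").contains(1234567890L))
  }
}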
Example 52
Source File: FilterTest.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari.graphql.controllers.filters

import org.coursera.naptime.ari.graphql.GraphqlSchemaProvider
import org.coursera.naptime.ari.graphql.Models
import org.coursera.naptime.ari.graphql.SangriaGraphQlContext
import org.coursera.naptime.ari.graphql.SangriaGraphQlSchemaBuilder
import org.coursera.naptime.ari.graphql.models.MergedCourse
import org.coursera.naptime.ari.graphql.models.MergedInstructor
import org.coursera.naptime.ari.graphql.models.MergedPartner
import org.mockito.Mockito.when
import org.scalatest.concurrent.IntegrationPatience
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.junit.AssertionsForJUnit
import org.scalatest.mockito.MockitoSugar
import play.api.libs.json.Json
import play.api.test.FakeRequest
import sangria.parser.QueryParser
import sangria.schema.Schema

import scala.concurrent.Future

trait FilterTest
    extends AssertionsForJUnit
    with MockitoSugar
    with ScalaFutures
    with IntegrationPatience {

  val baseOutgoingQuery = OutgoingQuery(Json.obj(), None)

  def noopFilter(incomingQuery: IncomingQuery) = {
    Future.successful(baseOutgoingQuery)
  }

  def exceptionThrowingFilter(incomingQuery: IncomingQuery): Future[OutgoingQuery] = {
    assert(false, "This filter should not be run")
    Future.successful(baseOutgoingQuery)
  }

  val filter: Filter

  val defaultQuery =
    """
      |query {
      |  __schema {
      |    queryType {
      |      name
      |    }
      |  }
      |}
    """.stripMargin

  val graphqlSchemaProvider = mock[GraphqlSchemaProvider]

  val allResources = Set(Models.courseResource, Models.instructorResource, Models.partnersResource)

  val schemaTypes = Map(
    "org.coursera.naptime.ari.graphql.models.MergedCourse" -> MergedCourse.SCHEMA,
    "org.coursera.naptime.ari.graphql.models.MergedPartner" -> MergedPartner.SCHEMA,
    "org.coursera.naptime.ari.graphql.models.MergedInstructor" -> MergedInstructor.SCHEMA)
  val builder = new SangriaGraphQlSchemaBuilder(allResources, schemaTypes)

  val schema = builder.generateSchema().data.asInstanceOf[Schema[SangriaGraphQlContext, Any]]
  when(graphqlSchemaProvider.schema).thenReturn(schema)

  def generateIncomingQuery(query: String = defaultQuery) = {
    val document = QueryParser.parse(query).get
    val header = FakeRequest("POST", s"/graphql").withBody(query)
    val variables = Json.obj()
    val operation = None
    IncomingQuery(document, header, variables, operation, debugMode = false)
  }

  def run(incomingQuery: IncomingQuery): Future[OutgoingQuery] = {
    filter.apply(noopFilter)(incomingQuery)
  }

  def ensureNotPropagated(incomingQuery: IncomingQuery): Future[OutgoingQuery] = {
    filter.apply(exceptionThrowingFilter)(incomingQuery)
  }
} 
Example 53
Source File: NaptimePaginatedResourceFieldTest.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari.graphql.schema

import org.coursera.naptime.ResourceName
import org.coursera.naptime.ari.graphql.Models
import org.coursera.naptime.ari.graphql.SangriaGraphQlContext
import org.coursera.naptime.ari.graphql.helpers.ArgumentBuilder
import org.junit.Test
import org.mockito.Mockito.when
import org.scalatest.junit.AssertionsForJUnit
import org.scalatest.mockito.MockitoSugar

import scala.concurrent.ExecutionContext

class NaptimePaginatedResourceFieldTest extends AssertionsForJUnit with MockitoSugar {

  val fieldName = "relatedIds"
  val resourceName = ResourceName("courses", 1)
  val context = SangriaGraphQlContext(null, null, ExecutionContext.global, debugMode = false)

  private[this] val schemaMetadata = mock[SchemaMetadata]
  private[this] val resource = Models.courseResource
  when(schemaMetadata.getResourceOpt(resourceName)).thenReturn(Some(resource))
  when(schemaMetadata.getSchema(resource)).thenReturn(Some(null))

  @Test
  def computeComplexity(): Unit = {
    val field = NaptimePaginatedResourceField.build(
      schemaMetadata,
      resourceName,
      fieldName,
      None,
      None,
      List.empty)

    val argDefinitions = NaptimePaginationField.paginationArguments

    val limitTen = field.right.get.complexity.get
      .apply(context, ArgumentBuilder.buildArgs(argDefinitions, Map("limit" -> Some(10))), 1)
    assert(limitTen === 1 * NaptimePaginatedResourceField.COMPLEXITY_COST * 1)

    val limitFifty = field.right.get.complexity.get
      .apply(context, ArgumentBuilder.buildArgs(argDefinitions, Map("limit" -> Some(50))), 1)
    assert(limitFifty === 5 * NaptimePaginatedResourceField.COMPLEXITY_COST * 1)

    val limitOne = field.right.get.complexity.get
      .apply(context, ArgumentBuilder.buildArgs(argDefinitions, Map("limit" -> Some(1))), 1)
    assert(limitOne === 1 * NaptimePaginatedResourceField.COMPLEXITY_COST * 1)

    val childScoreFive = field.right.get.complexity.get
      .apply(context, ArgumentBuilder.buildArgs(argDefinitions, Map("limit" -> Some(1))), 5)
    assert(childScoreFive === 1 * NaptimePaginatedResourceField.COMPLEXITY_COST * 5)

  }

} 
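
Lookups returning Option have to be stubbed explicitly: classic Mockito's default answer knows nothing about scala.Option, so an unstubbed call yields null rather than None. A small sketch with an illustrative Registry trait:

import org.mockito.Mockito.{mock, when}

object OptionStubSketch {
  trait Registry { def lookup(name: String): Option[String] }

  def main(args: Array[String]): Unit = {
    val registry = mock(classOf[Registry])
    when(registry.lookup("courses")).thenReturn(Some("v1"))

    assert(registry.lookup("courses").contains("v1"))
    assert(registry.lookup("missing") == null) // unstubbed: null, not None
  }
}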
Example 54
Source File: RestContextTest.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime

import org.junit.Test
import play.api.i18n.Lang
import play.api.mvc.Request
import org.mockito.Mockito.when
import org.scalatest.junit.AssertionsForJUnit
import org.scalatest.mockito.MockitoSugar

class RestContextTest extends AssertionsForJUnit with MockitoSugar {

  private[this] def makeContext(languagePreferences: Seq[Lang]): RestContext[Unit, Unit] = {
    val mockRequest = mock[Request[Unit]]
    val restContext = new RestContext((), (), mockRequest, null, null, null)
    when(mockRequest.acceptLanguages).thenReturn(languagePreferences)
    restContext
  }

  def test(
      requestLanguages: Seq[Lang],
      availableLanguages: Set[Lang],
      defaultLanguage: Lang,
      expected: Lang): Unit = {
    val restContext = makeContext(requestLanguages)
    assert(restContext.selectLanguage(availableLanguages, defaultLanguage) === expected)
  }

  @Test
  def basicLanguage(): Unit = {
    test(
      requestLanguages = Seq(Lang("en")),
      availableLanguages = Set(Lang("fr"), Lang("en")),
      defaultLanguage = Lang("en"),
      expected = Lang("en"))
  }

  @Test
  def defaultFallback(): Unit = {
    test(
      requestLanguages = Seq(Lang("zh")),
      availableLanguages = Set(Lang("fr"), Lang("en")),
      defaultLanguage = Lang("en"),
      expected = Lang("en"))
  }

  @Test
  def choosePreferred(): Unit = {
    test(
      requestLanguages = Seq(Lang("zh"), Lang("fr"), Lang("en")),
      availableLanguages = Set(Lang("fr"), Lang("en")),
      defaultLanguage = Lang("en"),
      expected = Lang("fr"))
  }
} 
Example 55
Source File: NaptimePlayRouterTest.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.router2

import akka.util.ByteString
import com.google.inject.Injector
import org.coursera.naptime.resources.RootResource
import org.coursera.naptime.schema.Handler
import org.coursera.naptime.schema.HandlerKind
import org.coursera.naptime.schema.Parameter
import org.coursera.naptime.schema.Resource
import org.coursera.naptime.schema.ResourceKind
import org.junit.Test
import org.mockito.Mockito.when
import org.mockito.Matchers.any
import org.scalatest.junit.AssertionsForJUnit
import org.scalatest.mockito.MockitoSugar
import play.api.libs.streams.Accumulator
import play.api.mvc.EssentialAction
import play.api.mvc.RequestHeader
import play.api.mvc.RequestTaggingHandler
import play.api.mvc.Result
import play.api.test.FakeRequest

class NaptimePlayRouterTest extends AssertionsForJUnit with MockitoSugar {
  object FakeHandler extends EssentialAction with RequestTaggingHandler {
    override def tagRequest(request: RequestHeader): RequestHeader = request

    override def apply(v1: RequestHeader): Accumulator[ByteString, Result] = ???
  }

  val resourceSchema = Resource(
    kind = ResourceKind.COLLECTION,
    name = "fakeResource",
    version = Some(1L),
    parentClass = Some(classOf[RootResource].getName),
    keyType = "java.lang.String",
    valueType = "FakeModel",
    mergedType = "FakeResourceModel",
    handlers = List(
      Handler(
        kind = HandlerKind.GET,
        name = "get",
        parameters =
          List(Parameter(name = "id", `type` = "String", attributes = List.empty, default = None)),
        inputBodyType = None,
        customOutputBodyType = None,
        attributes = List.empty)),
    className = "org.coursera.naptime.FakeResource",
    attributes = List.empty)

  val resourceRouter = mock[ResourceRouter]
  val resourceRouterBuilder = mock[ResourceRouterBuilder]
  when(resourceRouterBuilder.build(any())).thenReturn(resourceRouter)
  when(resourceRouterBuilder.schema).thenReturn(resourceSchema)

  val injector = mock[Injector]
  val naptimeRoutes = NaptimeRoutes(injector, Set(resourceRouterBuilder))
  val router = new NaptimePlayRouter(naptimeRoutes)

  @Test
  def simpleRouting(): Unit = {
    when(resourceRouter.routeRequest(any(), any())).thenReturn(Some(FakeHandler))
    val handler = router.handlerFor(FakeRequest())
    assert(handler.isDefined)
  }

  @Test
  def simpleRoutingNothing(): Unit = {
    when(resourceRouter.routeRequest(any(), any())).thenReturn(None)
    val handler = router.handlerFor(FakeRequest())
    assert(handler.isEmpty)
  }

  @Test
  def generateDocumentation(): Unit = {
    val documentation = router.documentation
    assert(1 === documentation.length)
    assert(
      (
        "GET --- GET",
        "/fakeResource.v1/$id",
        "[NAPTIME] org.coursera.naptime.FakeResource.get(id: String)") ===
        documentation.head)
  }
} 
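
The two routing tests re-stub routeRequest with different answers; stubbing the same call again simply replaces the earlier answer, which is what lets one shared mock serve both tests. In isolation, with an illustrative Router trait:

import org.mockito.Mockito.{mock, when}

object RestubSketch {
  trait Router { def route(path: String): Option[String] }

  def main(args: Array[String]): Unit = {
    val router = mock(classOf[Router])
    when(router.route("/a")).thenReturn(Some("handler"))
    assert(router.route("/a").contains("handler"))

    // Stubbing the same call again replaces the previous answer.
    when(router.route("/a")).thenReturn(None)
    assert(router.route("/a").isEmpty)
  }
}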
Example 56
Source File: ConsumerExtensionsSpec.scala    From embedded-kafka   with MIT License 5 votes vote down vote up
package net.manub.embeddedkafka

import net.manub.embeddedkafka.Codecs.stringValueCrDecoder
import net.manub.embeddedkafka.ConsumerExtensions._
import org.apache.kafka.clients.consumer.{
  ConsumerRecord,
  ConsumerRecords,
  KafkaConsumer
}
import org.apache.kafka.common.TopicPartition
import org.mockito.Mockito.{times, verify, when}
import org.scalatestplus.mockito.MockitoSugar

import scala.jdk.CollectionConverters._
import scala.concurrent.duration._

class ConsumerExtensionsSpec
    extends EmbeddedKafkaSpecSupport
    with MockitoSugar {

  "consumeLazily" should {
    "retry to get messages with the configured maximum number of attempts when poll fails" in {
      implicit val retryConf: ConsumerRetryConfig =
        ConsumerRetryConfig(2, 1.millis)

      val consumer = mock[KafkaConsumer[String, String]]
      val consumerRecords =
        new ConsumerRecords[String, String](
          Map
            .empty[TopicPartition, java.util.List[
              ConsumerRecord[String, String]
            ]]
            .asJava
        )

      when(consumer.poll(duration2JavaDuration(retryConf.poll)))
        .thenReturn(consumerRecords)

      consumer.consumeLazily[String]("topic")

      verify(consumer, times(retryConf.maximumAttempts))
        .poll(duration2JavaDuration(retryConf.poll))
    }

    "not retry to get messages with the configured maximum number of attempts when poll succeeds" in {
      implicit val retryConf: ConsumerRetryConfig =
        ConsumerRetryConfig(2, 1.millis)

      val consumer       = mock[KafkaConsumer[String, String]]
      val consumerRecord = mock[ConsumerRecord[String, String]]
      val consumerRecords = new ConsumerRecords[String, String](
        Map[TopicPartition, java.util.List[ConsumerRecord[String, String]]](
          new TopicPartition("topic", 1) -> List(consumerRecord).asJava
        ).asJava
      )

      when(consumer.poll(duration2JavaDuration(retryConf.poll)))
        .thenReturn(consumerRecords)

      consumer.consumeLazily[String]("topic")

      verify(consumer).poll(duration2JavaDuration(retryConf.poll))
    }

    "poll to get messages with the configured poll timeout" in {
      implicit val retryConf: ConsumerRetryConfig =
        ConsumerRetryConfig(1, 10.millis)

      val consumer = mock[KafkaConsumer[String, String]]
      val consumerRecords =
        new ConsumerRecords[String, String](
          Map
            .empty[TopicPartition, java.util.List[
              ConsumerRecord[String, String]
            ]]
            .asJava
        )

      when(consumer.poll(duration2JavaDuration(retryConf.poll)))
        .thenReturn(consumerRecords)

      consumer.consumeLazily[String]("topic")

      verify(consumer).poll(duration2JavaDuration(retryConf.poll))
    }
  }
} 
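
The retry test leans on the fact that a single thenReturn answers every subsequent call, so the mocked poll keeps returning the same empty records across all retry attempts. thenReturn also accepts several values for consecutive calls, with the last one repeating. A sketch with an illustrative Poller trait:

import org.mockito.Mockito.{mock, when}

object ConsecutiveStubSketch {
  trait Poller { def poll(): Int }

  def main(args: Array[String]): Unit = {
    val poller = mock(classOf[Poller])
    when(poller.poll()).thenReturn(0, 0, 1) // consecutive answers

    assert(poller.poll() == 0)
    assert(poller.poll() == 0)
    assert(poller.poll() == 1)
    assert(poller.poll() == 1) // the last stubbed value repeats from then on
  }
}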
Example 57
Source File: RawStageTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.test.stage

import akka.actor.ActorSystem
import akka.testkit.TestKit
import com.stratio.sparta.driver.stage.{LogError, RawDataStage}
import com.stratio.sparta.sdk.pipeline.autoCalculations.AutoCalculatedField
import com.stratio.sparta.sdk.properties.JsoneyString
import com.stratio.sparta.serving.core.models.policy.writer.{AutoCalculatedFieldModel, WriterModel}
import com.stratio.sparta.serving.core.models.policy.{PolicyModel, RawDataModel}
import org.junit.runner.RunWith
import org.mockito.Mockito.when
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpecLike, ShouldMatchers}

@RunWith(classOf[JUnitRunner])
class RawStageTest
  extends TestKit(ActorSystem("RawStageTest"))
    with FlatSpecLike with ShouldMatchers with MockitoSugar {

  case class TestRawData(policy: PolicyModel) extends RawDataStage with LogError

  def mockPolicy: PolicyModel = {
    val policy = mock[PolicyModel]
    when(policy.id).thenReturn(Some("id"))
    policy
  }

  "rawDataStage" should "Generate a raw data" in {
    val field = "field"
    val timeField = "time"
    val tableName = Some("table")
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val autocalculateFields = Seq(AutoCalculatedFieldModel())
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    when(writerModel.autoCalculatedFields).thenReturn(autocalculateFields)
    when(rawData.configuration).thenReturn(configuration)

    val result = TestRawData(policy).rawDataStage()

    result.timeField should be(timeField)
    result.dataField should be(field)
    result.writerOptions.tableName should be(tableName)
    result.writerOptions.partitionBy should be(partitionBy)
    result.configuration should be(configuration)
    result.writerOptions.outputs should be(outputs)
  }

  "rawDataStage" should "Fail with bad table name" in {
    val field = "field"
    val timeField = "time"
    val tableName = None
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    when(rawData.configuration).thenReturn(configuration)

    the[IllegalArgumentException] thrownBy {
      TestRawData(policy).rawDataStage()
    } should have message "Something gone wrong saving the raw data. Please re-check the policy."
  }

} 
Example 58
Source File: StreamAppSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.dsl.scalaapi

import akka.actor.ActorSystem
import org.apache.gearpump.cluster.TestUtil
import org.apache.gearpump.cluster.client.ClientContext
import org.apache.gearpump.streaming.dsl.scalaapi
import org.apache.gearpump.streaming.partitioner.PartitionerDescription
import org.apache.gearpump.streaming.source.DataSourceTask
import org.apache.gearpump.streaming.{ProcessorDescription, StreamApplication}
import org.apache.gearpump.util.Graph
import org.mockito.Mockito.when
import org.scalatest._
import org.scalatest.mock.MockitoSugar

import scala.concurrent.Await
import scala.concurrent.duration.Duration
class StreamAppSpec extends FlatSpec with Matchers with BeforeAndAfterAll with MockitoSugar {

  implicit var system: ActorSystem = _

  override def beforeAll(): Unit = {
    system = ActorSystem("test", TestUtil.DEFAULT_CONFIG)
  }

  override def afterAll(): Unit = {
    system.terminate()
    Await.result(system.whenTerminated, Duration.Inf)
  }

  it should "be able to generate multiple new streams" in {
    val context: ClientContext = mock[ClientContext]
    when(context.system).thenReturn(system)

    val dsl = StreamApp("dsl", context)
    dsl.source(List("A"), 2, "A") shouldBe a [scalaapi.Stream[_]]
    dsl.source(List("B"), 3, "B") shouldBe a [scalaapi.Stream[_]]

    val application = dsl.plan()
    application shouldBe a [StreamApplication]
    application.name shouldBe "dsl"
    val dag = application.userConfig
      .getValue[Graph[ProcessorDescription, PartitionerDescription]](StreamApplication.DAG).get
    dag.getVertices.size shouldBe 2
    dag.getVertices.foreach { processor =>
      processor.taskClass shouldBe classOf[DataSourceTask[_, _]].getName
      if (processor.description == "A") {
        processor.parallelism shouldBe 2
      } else if (processor.description == "B") {
        processor.parallelism shouldBe 3
      } else {
        fail(s"undefined source ${processor.description}")
      }
    }
  }
} 
Example 59
Source File: TransformTaskSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.dsl.task

import java.time.Instant

import org.apache.gearpump.Message
import org.apache.gearpump.cluster.UserConfig
import org.apache.gearpump.streaming.MockUtil
import org.apache.gearpump.streaming.dsl.window.impl.{TimestampedValue, TriggeredOutputs, StreamingOperator}
import org.mockito.Mockito.{verify, when}
import org.scalacheck.Gen
import org.scalatest.{Matchers, PropSpec}
import org.scalatest.mock.MockitoSugar
import org.scalatest.prop.PropertyChecks

class TransformTaskSpec extends PropSpec with PropertyChecks with Matchers with MockitoSugar {

  property("MergeTask should trigger on watermark") {
    val longGen = Gen.chooseNum[Long](1L, 1000L)
    val watermarkGen = longGen.map(Instant.ofEpochMilli)

    forAll(watermarkGen) { (watermark: Instant) =>
      val windowRunner = mock[StreamingOperator[Any, Any]]
      val context = MockUtil.mockTaskContext
      val config = UserConfig.empty
      val task = new TransformTask[Any, Any](windowRunner, context, config)
      val time = watermark.minusMillis(1L)
      val value: Any = time
      val message = Message(value, time)

      task.onNext(message)
      verify(windowRunner).foreach(TimestampedValue(value, time))

      when(windowRunner.trigger(watermark)).thenReturn(
        TriggeredOutputs(Some(TimestampedValue(value, time)), watermark))
      task.onWatermarkProgress(watermark)
      verify(context).output(message)
      verify(context).updateWatermark(watermark)
    }
  }

} 
Example 60
Source File: SessionHeartbeatSpec.scala    From incubator-livy   with Apache License 2.0 5 votes vote down vote up
package org.apache.livy.server.interactive

import scala.concurrent.duration._
import scala.concurrent.Future
import scala.language.postfixOps

import org.mockito.Mockito.{never, verify, when}
import org.scalatest.{FunSpec, Matchers}
import org.scalatest.concurrent.Eventually._
import org.scalatest.mock.MockitoSugar.mock

import org.apache.livy.LivyConf
import org.apache.livy.server.recovery.SessionStore
import org.apache.livy.sessions.{Session, SessionManager}
import org.apache.livy.sessions.Session.RecoveryMetadata

class SessionHeartbeatSpec extends FunSpec with Matchers {
  describe("SessionHeartbeat") {
    class TestHeartbeat(override val heartbeatTimeout: FiniteDuration) extends SessionHeartbeat {}

    it("should not expire if heartbeat was never called.") {
      val t = new TestHeartbeat(Duration.Zero)
      t.heartbeatExpired shouldBe false
    }

    it("should expire if time has elapsed.") {
      val t = new TestHeartbeat(Duration.fromNanos(1))
      t.heartbeat()
      eventually(timeout(2 nano), interval(1 nano)) {
        t.heartbeatExpired shouldBe true
      }
    }

    it("should not expire if time hasn't elapsed.") {
      val t = new TestHeartbeat(Duration.create(1, DAYS))
      t.heartbeat()
      t.heartbeatExpired shouldBe false
    }
  }

  describe("SessionHeartbeatWatchdog") {
    abstract class TestSession
      extends Session(0, None, null, null) with SessionHeartbeat {}
    class TestWatchdog(conf: LivyConf)
      extends SessionManager[TestSession, RecoveryMetadata](
        conf,
        { _ => assert(false).asInstanceOf[TestSession] },
        mock[SessionStore],
        "test",
        Some(Seq.empty))
        with SessionHeartbeatWatchdog[TestSession, RecoveryMetadata] {}

    it("should delete only expired sessions") {
      val expiredSession: TestSession = mock[TestSession]
      when(expiredSession.id).thenReturn(0)
      when(expiredSession.name).thenReturn(None)
      when(expiredSession.heartbeatExpired).thenReturn(true)
      when(expiredSession.stop()).thenReturn(Future.successful(()))
      when(expiredSession.lastActivity).thenReturn(System.nanoTime())

      val nonExpiredSession: TestSession = mock[TestSession]
      when(nonExpiredSession.id).thenReturn(1)
      when(nonExpiredSession.name).thenReturn(None)
      when(nonExpiredSession.heartbeatExpired).thenReturn(false)
      when(nonExpiredSession.stop()).thenReturn(Future.successful(()))
      when(nonExpiredSession.lastActivity).thenReturn(System.nanoTime())

      val n = new TestWatchdog(new LivyConf())

      n.register(expiredSession)
      n.register(nonExpiredSession)
      n.deleteExpiredSessions()

      verify(expiredSession).stop()
      verify(nonExpiredSession, never).stop()
    }
  }
} 
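
The watchdog test stubs two mocks of the same type with opposite heartbeatExpired answers, then verifies stop() on one and never on the other. The skeleton of that pattern, with an illustrative SessionLike trait:

import org.mockito.Mockito.{mock, never, verify, when}

object SelectiveStopSketch {
  trait SessionLike {
    def heartbeatExpired: Boolean
    def stop(): Unit
  }

  def main(args: Array[String]): Unit = {
    val expired = mock(classOf[SessionLike])
    val live = mock(classOf[SessionLike])
    when(expired.heartbeatExpired).thenReturn(true)
    when(live.heartbeatExpired).thenReturn(false)

    Seq(expired, live).filter(_.heartbeatExpired).foreach(_.stop())

    verify(expired).stop()     // the expired session was stopped
    verify(live, never).stop() // the live one was left alone
  }
}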
Example 61
Source File: MasterWebUISuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  // Send an HTTP request with the given method and optional form body, returning the open connection.
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 62
Source File: LogPageSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.worker.Worker

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    val worker = mock(classOf[Worker])
    val tmpDir = new File(sys.props("java.io.tmpdir"))
    val workDir = new File(tmpDir, "work-dir")
    workDir.mkdir()
    when(webui.workDir).thenReturn(workDir)
    when(webui.worker).thenReturn(worker)
    when(worker.conf).thenReturn(new SparkConf())
    val logPage = new LogPage(webui)

    // Prepare some fake log files to read later
    val out = "some stdout here"
    val err = "some stderr here"
    val tmpOut = new File(workDir, "stdout")
    val tmpErr = new File(workDir, "stderr")
    val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
    val tmpOutBad = new File(tmpDir, "stdout")
    val tmpRand = new File(workDir, "random")
    write(tmpOut, out)
    write(tmpErr, err)
    write(tmpOutBad, out)
    write(tmpErrBad, err)
    write(tmpRand, "1 6 4 5 2 7 8")

    // Get the logs. All log types other than "stderr" or "stdout" will be rejected
    val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
    val (stdout, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
    val (stderr, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
    val (error1, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
    val (error2, _, _, _) =
      logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
    // These files exist, but live outside the working directory
    val (error3, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
    val (error4, _, _, _) =
      logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
    assert(stdout === out)
    assert(stderr === err)
    assert(error1.startsWith("Error: Log type must be one of "))
    assert(error2.startsWith("Error: Log type must be one of "))
    assert(error3.startsWith("Error: invalid log directory"))
    assert(error4.startsWith("Error: invalid log directory"))
  }

  
  // Write the given string to the file.
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

} 
Example 63
Source File: StagePageSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener

class StagePageSuite extends SparkFunSuite with LocalSparkContext {

  private val peakExecutionMemory = 10

  test("peak execution memory should displayed") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    val targetString = "peak execution memory"
    assert(html.contains(targetString))
  }

  test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
    val conf = new SparkConf(false)
    val html = renderStagePage(conf).toString().toLowerCase
    // verify min/25/50/75/max show task value not cumulative values
    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
  }

  
  // Render a StagePage with the given SparkConf and return the resulting HTML nodes.
  private def renderStagePage(conf: SparkConf): Seq[Node] = {
    val jobListener = new JobProgressListener(conf)
    val graphListener = new RDDOperationGraphListener(conf)
    val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf)
    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
    val request = mock(classOf[HttpServletRequest])
    when(tab.conf).thenReturn(conf)
    when(tab.progressListener).thenReturn(jobListener)
    when(tab.operationGraphListener).thenReturn(graphListener)
    when(tab.executorsListener).thenReturn(executorsListener)
    when(tab.appName).thenReturn("testing")
    when(tab.headerTabs).thenReturn(Seq.empty)
    when(request.getParameter("id")).thenReturn("0")
    when(request.getParameter("attempt")).thenReturn("0")
    val page = new StagePage(tab)

    // Simulate a stage in job progress listener
    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
    // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness
    (1 to 2).foreach {
      taskId =>
        val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
        jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
        jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
        taskInfo.markFinished(TaskState.FINISHED)
        val taskMetrics = TaskMetrics.empty
        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
        jobListener.onTaskEnd(
          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
    }
    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
    page.render(request)
  }

} 
Example 64
Source File: HistoryServerSuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.history

import javax.servlet.http.HttpServletRequest

import scala.collection.mutable

import org.apache.hadoop.fs.Path
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.Matchers
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ui.SparkUI

class HistoryServerSuite extends FunSuite with Matchers with MockitoSugar {

  test("generate history page with relative links") {
    val historyServer = mock[HistoryServer]
    val request = mock[HttpServletRequest]
    val ui = mock[SparkUI]
    val link = "/history/app1"
    val info = new ApplicationHistoryInfo("app1", "app1", 0, 2, 1, "xxx", true)
    when(historyServer.getApplicationList()).thenReturn(Seq(info))
    when(ui.basePath).thenReturn(link)
    when(historyServer.getProviderConfig()).thenReturn(Map[String, String]())
    val page = new HistoryPage(historyServer)

    // when
    val response = page.render(request)

    // then
    val links = response \\ "a"
    val justHrefs = for {
      l <- links
      attrs <- l.attribute("href")
    } yield (attrs.toString)
    justHrefs should contain(link)
  }
} 
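
The suites in this collection alternate between Mockito's own mock(classOf[T]) and ScalaTest MockitoSugar's mock[T]; the sugar just supplies the Class token from the type parameter via a Manifest. A minimal sketch, assuming the ScalaTest-Mockito integration is on the classpath (the trait later moved to org.scalatestplus.mockito.MockitoSugar):

import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar

object SugarSketch extends MockitoSugar {
  trait Clock { def now: Long }

  def main(args: Array[String]): Unit = {
    val clock = mock[Clock] // type-parameter style, no classOf needed
    when(clock.now).thenReturn(1234567890L)
    assert(clock.now == 1234567890L)
  }
}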
Example 65
Source File: DiskBlockManagerSuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
}