scala.reflect.io.File Scala Examples

The following examples show how to use scala.reflect.io.File. Each example comes from an open-source project; the source file, project, and license are noted above each listing.
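Before the examples, here is a minimal sketch of the core scala.reflect.io.File operations the listings below rely on. This assumes a Scala 2.x standard library (where this API lives), and the path is hypothetical:

import scala.reflect.io.File

object FileBasics extends App {
  // File(path) builds a File via the companion apply; an implicit
  // scala.io.Codec in scope controls the character encoding.
  val f = File("/tmp/reflect-io-demo.txt")

  // writeAll truncates the file and writes the given strings.
  f.writeAll("hello\n", "world\n")

  // slurp reads the whole file back into a String.
  print(f.slurp())

  // deleteRecursively removes a file or directory tree, returning
  // true on success rather than throwing.
  f.deleteRecursively()
}

The examples also touch File.separator (the platform path separator), File.makeTemp (temporary files), and .jfile (the underlying java.io.File); these are called out where they appear.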
Example 1
Source File: SystemSpec.scala    From zio    with Apache License 2.0
package zio
package system

import scala.reflect.io.File

import zio.test.Assertion._
import zio.test._
import zio.test.environment.live

object SystemSpec extends ZIOBaseSpec {

  def spec = suite("SystemSpec")(
    suite("Fetch an environment variable and check that")(
      testM("If it exists, return a reasonable value") {
        assertM(live(system.env("PATH")))(isSome(containsString(File.separator + "bin")))
      },
      testM("If it does not exist, return None") {
        assertM(live(system.env("QWERTY")))(isNone)
      }
    ),
    suite("Fetch all environment variables and check that")(
      testM("If it exists, return a reasonable value") {
        assertM(live(system.envs.map(_.get("PATH"))))(isSome(containsString(File.separator + "bin")))
      },
      testM("If it does not exist, return None") {
        assertM(live(system.envs.map(_.get("QWERTY"))))(isNone)
      }
    ),
    suite("Fetch all VM properties and check that")(
      testM("If it exists, return a reasonable value") {
        assertM(live(properties.map(_.get("java.vm.name"))))(isSome(containsString("VM")))
      },
      testM("If it does not exist, return None") {
        assertM(live(properties.map(_.get("qwerty"))))(isNone)
      }
    ),
    suite("Fetch a VM property and check that")(
      testM("If it exists, return a reasonable value") {
        assertM(live(property("java.vm.name")))(isSome(containsString("VM")))
      },
      testM("If it does not exist, return None") {
        assertM(live(property("qwerty")))(isNone)
      }
    ),
    suite("Fetch the system's line separator and check that")(
      testM("it is identical to System.lineSeparator") {
        assertM(live(lineSeparator))(equalTo(java.lang.System.lineSeparator))
      }
    )
  )
} 
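The only part of scala.reflect.io.File used in Example 1 is the companion object's separator which, in the Scala 2.x library, simply forwards to java.io.File.separator, keeping the PATH assertion platform-aware. A one-line check:

import scala.reflect.io.File

object SeparatorCheck extends App {
  // "/" on Unix-like systems, "\" on Windows; always equal to the Java value.
  println(File.separator == java.io.File.separator) // true
}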
Example 2
Source File: FileOutputIT.scala    From sparta    with Apache License 2.0
package com.stratio.sparta

import java.sql.Timestamp
import java.util.UUID

import com.github.nscala_time.time.Imports._
import com.stratio.sparta.sdk.pipeline.output.{Output, OutputFormatEnum, SaveModeEnum}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest._

import scala.reflect.io.File


class FileOutputIT extends FlatSpec with ShouldMatchers with BeforeAndAfterAll {
  self: FlatSpec =>

  @transient var sc: SparkContext = _

  override def beforeAll {
    Logger.getRootLogger.setLevel(Level.ERROR)
    sc = FileOutputIT.getNewLocalSparkContext(1, "test")
  }

  override def afterAll {
    sc.stop()
    System.clearProperty("spark.driver.port")
  }

  trait CommonValues {

    val sqlContext = SQLContext.getOrCreate(sc)

    import sqlContext.implicits._

    val time = new Timestamp(DateTime.now.getMillis)

    val data =
      sc.parallelize(Seq(Person("Kevin", 18, time), Person("Kira", 21, time), Person("Ariadne", 26, time))).toDF

    val tmpPath: String = s"/tmp/sparta-test/${UUID.randomUUID().toString}"
  }

  trait WithEventData extends CommonValues {
    val properties = Map("path" -> tmpPath, "createDifferentFiles" -> "false")
    val output = new FileOutput("file-test", properties)
  }

  "FileOutputIT" should "save a dataframe" in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TimeDimensionKey -> "minute", Output.TableNameKey -> "person"))

    val source = new java.io.File(tmpPath).listFiles()
    val read = sqlContext.read.json(tmpPath).toDF
    read.count shouldBe(3)
    File("/tmp/sparta-test").deleteRecursively
  }
}

object FileOutputIT {

  def getNewLocalSparkContext(numExecutors: Int = 1, title: String): SparkContext = {
    val conf = new SparkConf().setMaster(s"local[$numExecutors]").setAppName(title)
    SparkContext.getOrCreate(conf)
  }
}

case class Person(name: String, age: Int, minute: Timestamp) extends Serializable 
Example 3
Source File: AvroOutputIT.scala    From sparta    with Apache License 2.0
package com.stratio.sparta.plugin.output.avro

import java.sql.Timestamp
import java.time.Instant

import com.databricks.spark.avro._
import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.reflect.io.File
import scala.util.Random


@RunWith(classOf[JUnitRunner])
class AvroOutputIT extends TemporalSparkContext with Matchers {

  trait CommonValues {
    val tmpPath: String = File.makeTemp().name
    val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("minute", LongType)
    ))

    val data =
      sparkSession.createDataFrame(sc.parallelize(Seq(
        Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime)
      )), schema)
  }

  trait WithEventData extends CommonValues {
    val properties = Map("path" -> tmpPath)
    val output = new AvroOutput("avro-test", properties)
  }


  "AvroOutput" should "throw an exception when path is not present" in {
    an[Exception] should be thrownBy new AvroOutput("avro-test", Map.empty)
  }

  it should "throw an exception when empty path " in {
    an[Exception] should be thrownBy new AvroOutput("avro-test", Map("path" -> "    "))
  }

  it should "save a dataframe " in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person"))
    val read = sparkSession.read.avro(s"$tmpPath/person")
    read.count should be(3)
    read should be eq data
    File(tmpPath).deleteRecursively
    File("spark-warehouse").deleteRecursively
  }

} 
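A caveat in Examples 3 and 4 (and in Example 14 below): File.makeTemp() creates a real temporary file under the system temp directory, but .name returns only the last path segment. tmpPath is therefore a bare relative name, which Spark resolves against the process working directory. A small sketch of the difference (printed paths are illustrative):

import scala.reflect.io.File

object MakeTempNames extends App {
  val tmp = File.makeTemp(prefix = "demo")
  println(tmp.path) // full path of the created file, e.g. /tmp/demo12345.tmp
  println(tmp.name) // only the last segment, e.g. demo12345.tmp
  tmp.delete()
}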
Example 4
Source File: CsvOutputIT.scala    From sparta    with Apache License 2.0
package com.stratio.sparta.plugin.output.csv

import java.sql.Timestamp
import java.time.Instant

import com.databricks.spark.avro._
import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.reflect.io.File
import scala.util.Random


@RunWith(classOf[JUnitRunner])
class CsvOutputIT extends TemporalSparkContext with Matchers {

  trait CommonValues {
    val tmpPath: String = File.makeTemp().name
    val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("minute", LongType)
    ))

    val data =
      sparkSession.createDataFrame(sc.parallelize(Seq(
        Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime)
      )), schema)
  }

  trait WithEventData extends CommonValues {
    val properties = Map("path" -> tmpPath)
    val output = new CsvOutput("csv-test", properties)
  }


  "CsvOutput" should "throw an exception when path is not present" in {
    an[Exception] should be thrownBy new CsvOutput("csv-test", Map.empty)
  }

  it should "throw an exception when empty path " in {
    an[Exception] should be thrownBy new CsvOutput("csv-test", Map("path" -> "    "))
  }

  it should "save a dataframe " in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person"))
    val read = sparkSession.read.csv(s"$tmpPath/person.csv")
    read.count should be(3)
    read should be eq data
    File(tmpPath).deleteRecursively
    File("spark-warehouse").deleteRecursively
  }

} 
Example 5
Source File: NLP4LMainGenericRunner.scala    From attic-nlp4l    with Apache License 2.0
package org.nlp4l.repl

      // NOTE: the listing is abridged here; it resumes mid-method inside
      // the runner's run logic (enclosing class and method definitions omitted).

      if (isE) {
        ScriptRunner.runCommand(settings, combinedCode, thingToRun +: command.arguments)
      }
      else runTarget() match {
        case Left(ex) => errorFn("", Some(ex))  // there must be a useful message of hope to offer here
        case Right(b) => b
      }
    }

    if (!command.ok)
      errorFn(f"%n$shortUsageMsg")
    else if (shouldStopWithInfo)
      errorFn(command getInfoMessage sampleCompiler, isFailure = false)
    else
      run()
  }
}

object NLP4LMainGenericRunner extends NLP4LMainGenericRunner {
  def main(args: Array[String]): Unit = {
    val conf = System.getProperty("nlp4l.conf")
    new sys.SystemProperties += ("scala.repl.autoruncode" -> conf)
    if (!process(args)) sys.exit(1)
  }
} 
Example 6
Source File: ReportPlugin.scala    From AppCrawler    with Apache License 2.0
package com.testerhome.appcrawler.plugin

import java.io

import com.testerhome.appcrawler.{Report, URIElement}
import com.testerhome.appcrawler._
import org.scalatest.FunSuite
import org.scalatest.tools.Runner
import sun.misc.{Signal, SignalHandler}

import scala.collection.mutable.ListBuffer
import scala.reflect.io.File


class ReportPlugin extends Plugin with Report {
  var lastSize=0
  override def start(): Unit ={
    reportPath=new java.io.File(getCrawler().conf.resultDir).getCanonicalPath
    log.info(s"reportPath=${reportPath}")
    val tmpDir=new io.File(s"${reportPath}/tmp/")
    if (!tmpDir.exists()) {
      log.info(s"create ${reportPath}/tmp/ directory")
      tmpDir.mkdir()
    }
  }

  override def stop(): Unit ={
    generateReport()
  }

  override def afterElementAction(element: URIElement): Unit ={
    val count=getCrawler().store.clickedElementsList.length
    log.info(s"clickedElementsList size = ${count}")
    val curSize=getCrawler().store.clickedElementsList.size
    if (curSize - lastSize > curSize / 10 + 20) {
      log.info(s"$curSize - $lastSize > $curSize/10 + 20")
      log.info("generate test report ")
      generateReport()
    }
  }

  def generateReport(): Unit ={
    Report.saveTestCase(getCrawler().store, getCrawler().conf.resultDir)
    Report.store=getCrawler().store
    Report.runTestCase()

    lastSize=getCrawler().store.clickedElementsList.size
  }


} 
Example 7
Source File: LogPlugin.scala    From AppCrawler    with Apache License 2.0
package com.testerhome.appcrawler.plugin

import java.util.logging.Level

import com.testerhome.appcrawler.driver.AppiumClient
import com.testerhome.appcrawler.{Plugin, URIElement}

import scala.collection.mutable.ListBuffer
import scala.reflect.io.File


class LogPlugin extends Plugin {
  private var logs = ListBuffer[String]()
  val driver = getCrawler().driver.asInstanceOf[AppiumClient].driver

  override def afterElementAction(element: URIElement): Unit = {
    // On the first run, probe which log types are usable; later runs can skip the probe and run faster
    if (logs.isEmpty) {
      driver.manage().logs().getAvailableLogTypes.toArray().foreach(logName => {
        log.info(s"read log=${logName.toString}")
        try {
          saveLog(logName.toString)
          logs += logName.toString
        } catch {
          case ex: Exception => log.warn(s"log=${logName.toString} not exist")
        }
      })
    }
    if(getCrawler().getElementAction()!="skip") {
      logs.foreach(log => {
        saveLog(log)
      })
    }
  }

  def saveLog(logName:String): Unit ={
    log.info(s"read log=${logName.toString}")
    val logMessage = driver.manage().logs.get(logName.toString).filter(Level.ALL).toArray()
    log.info(s"log=${logName} size=${logMessage.size}")
    if (logMessage.size > 0) {
      val fileName = getCrawler().getBasePathName()+".log"
      log.info(s"save ${logName} to $fileName")
      File(fileName).writeAll(logMessage.mkString("\n"))
      log.info(s"save ${logName} end")
    }
  }


  override def afterUrlRefresh(url: String): Unit = {

  }
  override def stop(): Unit ={
    logs.foreach(log => {
      saveLog(log)
    })
  }

} 
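Note that saveLog writes with writeAll, which truncates the target file on every call, so each save replaces the previous contents for that file name. If accumulating log entries were the goal, the same class offers appendAll; a sketch contrasting the two (the path is hypothetical):

import scala.reflect.io.File

object WriteVsAppend extends App {
  val f = File("demo.log")
  f.writeAll("first\n")
  f.writeAll("second\n")  // truncates: the file now holds only "second"
  f.appendAll("third\n")  // appends: the file now holds "second" then "third"
  print(f.slurp())
  f.delete()
}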
Example 8
Source File: Report.scala    From AppCrawler    with Apache License 2.0
package com.testerhome.appcrawler

import org.apache.commons.io.FileUtils
import org.scalatest.tools.Runner

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.io.{Source, Codec}
import scala.reflect.io.File
import collection.JavaConversions._

    // NOTE: the listing is abridged here; it resumes mid-method
    // (enclosing class and method definitions omitted).
    log.info(s"run ${cmdArgs.mkString(" ")}")
    Runner.run(cmdArgs)
    changeTitle()
  }

  def changeTitle(title:String=Report.title): Unit ={
    val originTitle="ScalaTest Results"
    val indexFile=reportPath+"/index.html"
    val newContent=Source.fromFile(indexFile).mkString.replace(originTitle, title)
    scala.reflect.io.File(indexFile).writeAll(newContent)
  }

}

object Report extends Report{
  var showCancel=false
  var title="AppCrawler"
  var master=""
  var candidate=""
  var reportDir=""
  var store=new URIElementStore


  def loadResult(elementsFile: String): URIElementStore ={
    DataObject.fromYaml[URIElementStore](Source.fromFile(elementsFile).mkString)
  }
} 
Example 9
Source File: TestOCR.scala    From AppCrawler    with Apache License 2.0
// Imports reconstructed for this abridged listing; the original source may differ.
import java.awt.BasicStroke
import javax.imageio.ImageIO

import net.sourceforge.tess4j.Tesseract
import net.sourceforge.tess4j.ITessAPI.TessPageIteratorLevel
import org.scalatest.FunSuite

import scala.collection.JavaConversions._

class TestOCR extends FunSuite {

  test("test ocr"){
    val api=new Tesseract()
    api.setDatapath("/Users/seveniruby/Downloads/")
    api.setLanguage("eng+chi_sim")
    val img=new java.io.File("/Users/seveniruby/temp/google-test7.png")
    val imgFile=ImageIO.read(img)
    val graph=imgFile.createGraphics()
    graph.setStroke(new BasicStroke(5))

    val result=api.doOCR(img)

    val words=api.getWords(imgFile, TessPageIteratorLevel.RIL_WORD).toList
    words.foreach(word=>{

      val box=word.getBoundingBox
      val x=box.getX.toInt
      val y=box.getY.toInt
      val w=box.getWidth.toInt
      val h=box.getHeight.toInt

      graph.drawRect(x, y, w, h)
      graph.drawString(word.getText, x, y)

      println(word.getBoundingBox)
      println(word.getText)
    })
    graph.dispose()
    ImageIO.write(imgFile, "png", new java.io.File(s"${img}.mark.png"))



    println(result)

  }

}
Example 10
Source File: NLPPreprocessTest.scala    From CkoocNLP    with Apache License 2.0
package nlp

import com.hankcs.hanlp.utility.Predefine
import functions.clean.Cleaner
import functions.segment.Segmenter
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.junit.Test

import scala.reflect.io.File

// Enclosing class omitted in the original listing; reconstructed from the file name.
class NLPPreprocessTest {
  @Test
  def testSegmenter(): Unit = {
    val spark = SparkSession
      .builder
      .master("local[2]")
      .appName("Segment Demo")
      .getOrCreate()

    val text = Seq(
      (0, "这段文本是用来做分词测试的!This text is for test!"),
      (1, "江州市长江大桥参加长江大桥通车仪式"),
      (2, "他邀请了不少于10个明星,有:范冰冰、赵薇、周杰伦等,还有20几位商业大佬")
    )
    val sentenceData = spark.createDataFrame(text).toDF("id", "sentence")

    // Set the HanLP configuration file path (by default it is resolved from the classpath)
    val path = this.getClass.getClassLoader.getResource("").getPath
    Predefine.HANLP_PROPERTIES_PATH = path + File.separator + "hanlp.properties"

    val segmenter = new Segmenter()
      .isDelEn(true)
      .isDelNum(true)
      .isAddNature(true)
      .setSegType("StandardSegment")
      .setMinTermLen(2)
      .setMinTermNum(3)
      .setInputCol("sentence")
      .setOutputCol("segmented")

    segmenter.transform(sentenceData).show(false)

    spark.stop()
  }
} 
Example 11
Source File: Configuration.scala    From toketi-iothubreact    with MIT License
// Copyright (c) Microsoft. All rights reserved.

package it.helpers

import java.nio.file.{Files, Paths}

import com.microsoft.azure.eventhubs.EventHubClient
import com.typesafe.config.{Config, ConfigFactory}
import org.json4s._
import org.json4s.jackson.JsonMethods._
import scala.reflect.io.File


object Configuration {

  // JSON parser setup, brings in default date formats etc.
  implicit val formats = DefaultFormats

  private[this] val confConnPath      = "iothub-react.connection."
  private[this] val confStreamingPath = "iothub-react.streaming."

  private[this] val conf: Config = ConfigFactory.load()

  // Read-only settings
  val iotHubNamespace : String = conf.getString(confConnPath + "namespace")
  val iotHubName      : String = conf.getString(confConnPath + "name")
  val iotHubPartitions: Int    = conf.getInt(confConnPath + "partitions")
  val accessPolicy    : String = conf.getString(confConnPath + "accessPolicy")
  val accessKey       : String = conf.getString(confConnPath + "accessKey")

  // Tests can override these
  var receiverConsumerGroup: String = EventHubClient.DEFAULT_CONSUMER_GROUP_NAME
  var receiverTimeout      : Long   = conf.getDuration(confStreamingPath + "receiverTimeout").toMillis
  var receiverBatchSize    : Int    = conf.getInt(confStreamingPath + "receiverBatchSize")

  // Read devices configuration from JSON file
  private[this] lazy val devicesJsonFile                       = conf.getString(confConnPath + "devices")
  private[this] lazy val devicesJson: String                   = File(devicesJsonFile).slurp()
  private[this] lazy val devices    : Array[DeviceCredentials] = parse(devicesJson).extract[Array[DeviceCredentials]]

  def deviceCredentials(id: String): DeviceCredentials = {
    val deviceData: Option[DeviceCredentials] = devices.find(x ⇒ x.deviceId == id)
    if (deviceData == None) {
      throw new RuntimeException(s"Device '${id}' credentials not found")
    }
    deviceData.get
  }

  if (!Files.exists(Paths.get(devicesJsonFile))) {
    throw new RuntimeException("Devices credentials not found")
  }
} 
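The devicesJson line above is the idiomatic whole-file read: File(path).slurp() returns the entire contents as a String, decoded with the implicit scala.io.Codec in scope (the JVM default charset if none is supplied). A minimal sketch with a hypothetical file:

import scala.io.Codec
import scala.reflect.io.File

object SlurpSketch extends App {
  implicit val codec: Codec = Codec.UTF8 // picked up by File's constructor
  val f = File("devices.json")
  f.writeAll("""[{"deviceId": "d1"}]""")
  println(f.slurp())
  f.delete()
}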
Example 12
Source File: EmbeddedKafkaConnect.scala    From ksql-jdbc-driver    with Apache License 2.0
package com.github.mmolimar.ksql.jdbc.embedded

import java.util

import com.github.mmolimar.ksql.jdbc.utils.TestUtils
import kafka.utils.Logging
import org.apache.kafka.common.utils.Time
import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy
import org.apache.kafka.connect.runtime.isolation.Plugins
import org.apache.kafka.connect.runtime.rest.RestServer
import org.apache.kafka.connect.runtime.standalone.{StandaloneConfig, StandaloneHerder}
import org.apache.kafka.connect.runtime.{Connect, Worker, WorkerConfig}
import org.apache.kafka.connect.storage.FileOffsetBackingStore
import org.apache.kafka.connect.util.ConnectUtils

import scala.collection.JavaConverters._
import scala.reflect.io.File

class EmbeddedKafkaConnect(brokerList: String, port: Int = TestUtils.getAvailablePort) extends Logging {

  private val workerProps: util.Map[String, String] = Map[String, String](
    WorkerConfig.LISTENERS_CONFIG -> s"http://localhost:$port",
    WorkerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokerList,
    WorkerConfig.KEY_CONVERTER_CLASS_CONFIG -> "org.apache.kafka.connect.converters.ByteArrayConverter",
    WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG -> "org.apache.kafka.connect.converters.ByteArrayConverter",
    StandaloneConfig.OFFSET_STORAGE_FILE_FILENAME_CONFIG -> File.makeTemp(prefix = "connect.offsets").jfile.getAbsolutePath
  ).asJava

  private lazy val kafkaConnect: Connect = buildConnect

  def startup(): Unit = {
    info("Starting up embedded Kafka connect")

    kafkaConnect.start()

    info(s"Started embedded Kafka connect on port: $port")
  }

  def shutdown(): Unit = {
    info("Shutting down embedded Kafka Connect")

    TestUtils.swallow(kafkaConnect.stop())

    info("Stopped embedded Kafka Connect")
  }

  private def buildConnect: Connect = {
    val config = new StandaloneConfig(workerProps)
    val kafkaClusterId = ConnectUtils.lookupKafkaClusterId(config)

    val rest = new RestServer(config)
    rest.initializeServer()

    val advertisedUrl = rest.advertisedUrl
    val workerId = advertisedUrl.getHost + ":" + advertisedUrl.getPort
    val plugins = new Plugins(workerProps)
    val connectorClientConfigOverridePolicy = plugins.newPlugin(
      config.getString(WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG), config, classOf[ConnectorClientConfigOverridePolicy])
    val worker = new Worker(workerId, Time.SYSTEM, plugins, config, new FileOffsetBackingStore, connectorClientConfigOverridePolicy)
    val herder = new StandaloneHerder(worker, kafkaClusterId, connectorClientConfigOverridePolicy)

    new Connect(herder, rest)
  }

  def getPort: Int = port

  def getWorker: String = s"localhost:$port"

  def getUrl: String = s"http://localhost:$port"

  override def toString: String = {
    val sb: StringBuilder = StringBuilder.newBuilder
    sb.append("KafkaConnect{")
    sb.append("port=").append(port)
    sb.append('}')

    sb.toString
  }

} 
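The OFFSET_STORAGE_FILE_FILENAME_CONFIG entry above shows the usual bridge to Java APIs: File.makeTemp creates a temporary file, and .jfile exposes the wrapped java.io.File for code that expects one. The same pattern in isolation:

import scala.reflect.io.File

object JFileBridge extends App {
  val tmp = File.makeTemp(prefix = "connect.offsets")
  val javaFile: java.io.File = tmp.jfile // the underlying java.io.File
  println(javaFile.getAbsolutePath)
  tmp.delete()
}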
Example 13
Source File: CheckpointTest.scala    From scio    with Apache License 2.0
package com.spotify.scio.extra.checkpoint

import java.nio.file.Files

import com.spotify.scio.{ContextAndArgs, ScioMetrics}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.reflect.io.File
import scala.util.Try

object CheckpointMetrics {
  def runJob(checkpointArg: String, tempLocation: String = null): (Long, Long) = {
    val elemsBefore = ScioMetrics.counter("elemsBefore")
    val elemsAfter = ScioMetrics.counter("elemsAfter")

    val (sc, args) = ContextAndArgs(
      Array(s"--checkpoint=$checkpointArg") ++
        Option(tempLocation).map(e => s"--tempLocation=$e")
    )
    sc.checkpoint(args("checkpoint")) {
      sc.parallelize(1 to 10)
        .map { x => elemsBefore.inc(); x }
    }.map { x => elemsAfter.inc(); x }
    val r = sc.run().waitUntilDone()
    (Try(r.counter(elemsBefore).committed.get).getOrElse(0), r.counter(elemsAfter).committed.get)
  }
}

class CheckpointTest extends AnyFlatSpec with Matchers {
  import CheckpointMetrics._

  "checkpoint" should "work on path" in {
    val tmpDir =
      Files.createTempDirectory("checkpoint-").resolve("checkpoint").toString
    runJob(tmpDir) shouldBe ((10L, 10L))
    runJob(tmpDir) shouldBe ((0L, 10L))
    File(tmpDir).deleteRecursively()
    runJob(tmpDir) shouldBe ((10L, 10L))
  }

  it should "work on name/file" in {
    val checkpointName = "c1"
    val tempLocation = Files.createTempDirectory("temp-location-").toString
    runJob(checkpointName, tempLocation) shouldBe ((10L, 10L))
    runJob(checkpointName, tempLocation) shouldBe ((0L, 10L))
    File(s"$tempLocation/$checkpointName").deleteRecursively()
    runJob(checkpointName, tempLocation) shouldBe ((10L, 10L))
  }
} 
Example 14
Source File: ParquetOutputIT.scala    From sparta    with Apache License 2.0
package com.stratio.sparta.plugin.output.parquet

import com.github.nscala_time.time.Imports._
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.reflect.io.File

@RunWith(classOf[JUnitRunner])
class ParquetOutputIT extends FlatSpec with ShouldMatchers with BeforeAndAfterAll {
  self: FlatSpec =>

  @transient var sc: SparkContext = _

  override def beforeAll {
    Logger.getRootLogger.setLevel(Level.ERROR)
    sc = ParquetOutputIT.getNewLocalSparkContext(1, "test")
  }

  override def afterAll {
    sc.stop()
    System.clearProperty("spark.driver.port")
  }

  trait CommonValues {

    val sqlContext = SparkSession.builder().config(sc.getConf).getOrCreate()

    import sqlContext.implicits._

    val time = DateTime.now.getMillis

    val data =
      sc.parallelize(Seq(Person("Kevin", 18, time), Person("Kira", 21, time), Person("Ariadne", 26, time))).toDS().toDF

    val tmpPath: String = File.makeTemp().name
  }

  trait WithEventData extends CommonValues {

    val properties = Map("path" -> tmpPath)
    val output = new ParquetOutput("parquet-test", properties)
  }

  trait WithoutGranularity extends CommonValues {

    val datePattern = "yyyy/MM/dd"
    val properties = Map("path" -> tmpPath, "datePattern" -> datePattern)
    val output = new ParquetOutput("parquet-test", properties)
    val expectedPath = "/0"
  }

  "ParquetOutputIT" should "save a dataframe" in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person"))
    val read = sqlContext.read.parquet(s"$tmpPath/person").toDF
    read.count should be(3)
    read should be eq (data)
    File(tmpPath).deleteRecursively
    File("spark-warehouse").deleteRecursively
  }

  it should "throw an exception when path is not present" in {
    an[Exception] should be thrownBy new ParquetOutput("parquet-test", Map())
  }
}

object ParquetOutputIT {

  def getNewLocalSparkContext(numExecutors: Int = 1, title: String): SparkContext = {
    val conf = new SparkConf().setMaster(s"local[$numExecutors]").setAppName(title)
    SparkContext.getOrCreate(conf)
  }
}

case class Person(name: String, age: Int, minute: Long) extends Serializable