org.apache.spark.util.Utils Scala Examples

The following examples show how to use org.apache.spark.util.Utils. Each one is an excerpt from the drizzle-spark project (a fork of Apache Spark, Apache License 2.0); the source file an example comes from is listed above its code.
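Note that Utils is declared private[spark], so it can only be called from code that lives somewhere under the org.apache.spark package, which is true of every example on this page. The single most common pattern in these examples is creating a scratch directory with Utils.createTempDir() and removing it with Utils.deleteRecursively(); a minimal hedged sketch of that pattern follows (the package and object names are illustrative, not part of Spark):

// Illustrative sketch only: must live under org.apache.spark to see the private[spark] Utils.
package org.apache.spark.examples.sketch

import java.io.File

import org.apache.spark.util.Utils

object TempDirSketch {
  def main(args: Array[String]): Unit = {
    // Create a scratch directory and guarantee it is cleaned up afterwards.
    val tmpDir = Utils.createTempDir()
    try {
      val outputDir = new File(tmpDir, "output")
      println(s"Writing scratch data under $outputDir")
    } finally {
      Utils.deleteRecursively(tmpDir)
    }
  }
}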
Example 1
Source File: DataFrameExample.scala    From drizzle-spark with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

import java.io.File

import scopt.OptionParser

import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.Utils


object DataFrameExample {

  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
    extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DataFrameExample") {
      head("DataFrameExample: an example app using DataFrame for ML.")
      opt[String]("input")
        .text(s"input path to dataframe")
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        success
      }
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case _ => sys.exit(1)
    }
  }

  def run(params: Params): Unit = {
    val spark = SparkSession
      .builder
      .appName(s"DataFrameExample with $params")
      .getOrCreate()

    // Load input data
    println(s"Loading LIBSVM file with UDT from ${params.input}.")
    val df: DataFrame = spark.read.format("libsvm").load(params.input).cache()
    println("Schema from LIBSVM:")
    df.printSchema()
    println(s"Loaded training data as a DataFrame with ${df.count()} records.")

    // Show statistical summary of labels.
    val labelSummary = df.describe("label")
    labelSummary.show()

    // Convert features column to an RDD of vectors.
    val features = df.select("features").rdd.map { case Row(v: Vector) => v }
    val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
      (summary, feat) => summary.add(Vectors.fromML(feat)),
      (sum1, sum2) => sum1.merge(sum2))
    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

    // Save the records in a parquet file.
    val tmpDir = Utils.createTempDir()
    val outputDir = new File(tmpDir, "dataframe").toString
    println(s"Saving to $outputDir as Parquet file.")
    df.write.parquet(outputDir)

    // Load the records back.
    println(s"Loading Parquet file with UDT from $outputDir.")
    val newDF = spark.read.parquet(outputDir)
    println(s"Schema from Parquet:")
    newDF.printSchema()

    spark.stop()
  }
}
// scalastyle:on println 
Example 2
Source File: CommandUtils.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.deploy.worker

import java.io.{File, FileOutputStream, InputStream, IOException}

import scala.collection.JavaConverters._
import scala.collection.Map

import org.apache.spark.SecurityManager
import org.apache.spark.deploy.Command
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.WorkerCommandBuilder
import org.apache.spark.util.Utils


private[deploy]
object CommandUtils extends Logging {

  // (the command-building helpers of CommandUtils are omitted in this excerpt)

  def redirectStream(in: InputStream, file: File) {
    val out = new FileOutputStream(file, true)
    // TODO: It would be nice to add a shutdown hook here that explains why the output is
    //       terminating. Otherwise if the worker dies the executor logs will silently stop.
    new Thread("redirect output to " + file) {
      override def run() {
        try {
          Utils.copyStream(in, out, true)
        } catch {
          case e: IOException =>
            logInfo("Redirection to " + file + " closed: " + e.getMessage)
        }
      }
    }.start()
  }
} 
Example 3
Source File: RateController.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator)
    extends StreamingListener with Serializable {

  // (the rateLimit field, execution context setup and the abstract publish(rate: Long)
  //  member are omitted in this excerpt)

  private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
    Future[Unit] {
      val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay)
      newRate.foreach { s =>
        rateLimit.set(s.toLong)
        publish(getLatestRate())
      }
    }

  def getLatestRate(): Long = rateLimit.get()

  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
    val elements = batchCompleted.batchInfo.streamIdToInputInfo

    for {
      processingEnd <- batchCompleted.batchInfo.processingEndTime
      workDelay <- batchCompleted.batchInfo.processingDelay
      waitDelay <- batchCompleted.batchInfo.schedulingDelay
      elems <- elements.get(streamUID).map(_.numRecords)
    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
  }
}

object RateController {
  def isBackPressureEnabled(conf: SparkConf): Boolean =
    conf.getBoolean("spark.streaming.backpressure.enabled", false)
} 
Example 4
Source File: LauncherBackend.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.launcher

import java.net.{InetAddress, Socket}

import org.apache.spark.SPARK_VERSION
import org.apache.spark.launcher.LauncherProtocol._
import org.apache.spark.util.{ThreadUtils, Utils}


private[spark] abstract class LauncherBackend {

  // (connection setup, state reporting and the abstract onStopRequest() member
  //  are omitted in this excerpt)

  protected def onDisconnected() : Unit = { }

  private def fireStopRequest(): Unit = {
    val thread = LauncherBackend.threadFactory.newThread(new Runnable() {
      override def run(): Unit = Utils.tryLogNonFatalError {
        onStopRequest()
      }
    })
    thread.start()
  }

  private class BackendConnection(s: Socket) extends LauncherConnection(s) {

    override protected def handle(m: Message): Unit = m match {
      case _: Stop =>
        fireStopRequest()

      case _ =>
        throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}")
    }

    override def close(): Unit = {
      try {
        super.close()
      } finally {
        onDisconnected()
        _isConnected = false
      }
    }

  }

}

private object LauncherBackend {

  val threadFactory = ThreadUtils.namedThreadFactory("LauncherBackend")

} 
Example 5
Source File: StreamingTestExample.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.examples.mllib

import org.apache.spark.SparkConf
import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.util.Utils


object StreamingTestExample {

  def main(args: Array[String]) {
    if (args.length != 3) {
      // scalastyle:off println
      System.err.println(
        "Usage: StreamingTestExample " +
          "<dataDir> <batchDuration> <numBatchesTimeout>")
      // scalastyle:on println
      System.exit(1)
    }
    val dataDir = args(0)
    val batchDuration = Seconds(args(1).toLong)
    val numBatchesTimeout = args(2).toInt

    val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample")
    val ssc = new StreamingContext(conf, batchDuration)
    ssc.checkpoint {
      val dir = Utils.createTempDir()
      dir.toString
    }

    // $example on$
    val data = ssc.textFileStream(dataDir).map(line => line.split(",") match {
      case Array(label, value) => BinarySample(label.toBoolean, value.toDouble)
    })

    val streamingTest = new StreamingTest()
      .setPeacePeriod(0)
      .setWindowSize(0)
      .setTestMethod("welch")

    val out = streamingTest.registerStream(data)
    out.print()
    // $example off$

    // Stop processing if test becomes significant or we time out
    var timeoutCounter = numBatchesTimeout
    out.foreachRDD { rdd =>
      timeoutCounter -= 1
      val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _)
      if (timeoutCounter == 0 || anySignificant) rdd.context.stop()
    }

    ssc.start()
    ssc.awaitTermination()
  }
} 
Example 6
Source File: DriverSubmissionTest.scala    From drizzle-spark with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples

import scala.collection.JavaConverters._

import org.apache.spark.util.Utils


object DriverSubmissionTest {
  def main(args: Array[String]) {
    if (args.length < 1) {
      println("Usage: DriverSubmissionTest <seconds-to-sleep>")
      System.exit(0)
    }
    val numSecondsToSleep = args(0).toInt

    val env = System.getenv()
    val properties = Utils.getSystemProperties

    println("Environment variables containing SPARK_TEST:")
    env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println)

    println("System properties containing spark.test:")
    properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println)

    for (i <- 1 until numSecondsToSleep) {
      println(s"Alive for $i out of $numSecondsToSleep seconds")
      Thread.sleep(1000)
    }
  }
}
// scalastyle:on println 
Example 7
Source File: MesosClusterDispatcherArguments.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.deploy.mesos

import scala.annotation.tailrec

import org.apache.spark.SparkConf
import org.apache.spark.util.{IntParam, Utils}


private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) {
  var host = Utils.localHostName()
  var port = 7077
  var name = "Spark Cluster"
  var webUiPort = 8081
  var masterUrl: String = _
  var zookeeperUrl: Option[String] = None
  var propertiesFile: String = _

  parse(args.toList)

  propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)

  @tailrec
  private def parse(args: List[String]): Unit = args match {
    case ("--host" | "-h") :: value :: tail =>
      Utils.checkHost(value, "Please use hostname " + value)
      host = value
      parse(tail)

    case ("--port" | "-p") :: IntParam(value) :: tail =>
      port = value
      parse(tail)

    case ("--webui-port") :: IntParam(value) :: tail =>
      webUiPort = value
      parse(tail)

    case ("--zk" | "-z") :: value :: tail =>
      zookeeperUrl = Some(value)
      parse(tail)

    case ("--master" | "-m") :: value :: tail =>
      if (!value.startsWith("mesos://")) {
        // scalastyle:off println
        System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)")
        // scalastyle:on println
        System.exit(1)
      }
      masterUrl = value.stripPrefix("mesos://")
      parse(tail)

    case ("--name") :: value :: tail =>
      name = value
      parse(tail)

    case ("--properties-file") :: value :: tail =>
      propertiesFile = value
      parse(tail)

    case ("--help") :: tail =>
      printUsageAndExit(0)

    case Nil =>
      if (masterUrl == null) {
        // scalastyle:off println
        System.err.println("--master is required")
        // scalastyle:on println
        printUsageAndExit(1)
      }

    case _ =>
      printUsageAndExit(1)
  }

  private def printUsageAndExit(exitCode: Int): Unit = {
    // scalastyle:off println
    System.err.println(
      "Usage: MesosClusterDispatcher [options]\n" +
        "\n" +
        "Options:\n" +
        "  -h HOST, --host HOST    Hostname to listen on\n" +
        "  -p PORT, --port PORT    Port to listen on (default: 7077)\n" +
        "  --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" +
        "  --name NAME             Framework name to show in Mesos UI\n" +
        "  -m --master MASTER      URI for connecting to Mesos master\n" +
        "  -z --zk ZOOKEEPER       Comma delimited URLs for connecting to \n" +
        "                          Zookeeper for persistence\n" +
        "  --properties-file FILE  Path to a custom Spark properties file.\n" +
        "                          Default is conf/spark-defaults.conf.")
    // scalastyle:on println
    System.exit(exitCode)
  }
} 
Example 8
Source File: MesosClusterDispatcher.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.deploy.mesos

import java.util.concurrent.CountDownLatch

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.mesos.ui.MesosClusterUI
import org.apache.spark.deploy.rest.mesos.MesosRestServer
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.cluster.mesos._
import org.apache.spark.util.{ShutdownHookManager, Utils}


private[mesos] class MesosClusterDispatcher(
    args: MesosClusterDispatcherArguments,
    conf: SparkConf)
  extends Logging {

  private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host)
  private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase()
  logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode)

  private val engineFactory = recoveryMode match {
    case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory
    case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf)
    case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode)
  }

  private val scheduler = new MesosClusterScheduler(engineFactory, conf)

  private val server = new MesosRestServer(args.host, args.port, conf, scheduler)
  private val webUi = new MesosClusterUI(
    new SecurityManager(conf),
    args.webUiPort,
    conf,
    publicAddress,
    scheduler)

  private val shutdownLatch = new CountDownLatch(1)

  def start(): Unit = {
    webUi.bind()
    scheduler.frameworkUrl = conf.get("spark.mesos.dispatcher.webui.url", webUi.activeWebUiUrl)
    scheduler.start()
    server.start()
  }

  def awaitShutdown(): Unit = {
    shutdownLatch.await()
  }

  def stop(): Unit = {
    webUi.stop()
    server.stop()
    scheduler.stop()
    shutdownLatch.countDown()
  }
}

private[mesos] object MesosClusterDispatcher extends Logging {
  def main(args: Array[String]) {
    Utils.initDaemon(log)
    val conf = new SparkConf
    val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf)
    conf.setMaster(dispatcherArgs.masterUrl)
    conf.setAppName(dispatcherArgs.name)
    dispatcherArgs.zookeeperUrl.foreach { z =>
      conf.set("spark.deploy.recoveryMode", "ZOOKEEPER")
      conf.set("spark.deploy.zookeeper.url", z)
    }
    val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf)
    dispatcher.start()
    logDebug("Adding shutdown hook") // force eager creation of logger
    ShutdownHookManager.addShutdownHook { () =>
      logInfo("Shutdown hook is shutting down dispatcher")
      dispatcher.stop()
      dispatcher.awaitShutdown()
    }
    dispatcher.awaitShutdown()
  }
} 
Example 9
Source File: MesosClusterPersistenceEngine.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.scheduler.cluster.mesos

import scala.collection.JavaConverters._

import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
import org.apache.zookeeper.KeeperException.NoNodeException

import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkCuratorUtil
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] class ZookeeperMesosClusterPersistenceEngine(
    baseDir: String,
    zk: CuratorFramework,
    conf: SparkConf)
  extends MesosClusterPersistenceEngine with Logging {
  private val WORKING_DIR =
    conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir

  SparkCuratorUtil.mkdir(zk, WORKING_DIR)

  def path(name: String): String = {
    WORKING_DIR + "/" + name
  }

  override def expunge(name: String): Unit = {
    zk.delete().forPath(path(name))
  }

  override def persist(name: String, obj: Object): Unit = {
    val serialized = Utils.serialize(obj)
    val zkPath = path(name)
    zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized)
  }

  override def fetch[T](name: String): Option[T] = {
    val zkPath = path(name)

    try {
      val fileData = zk.getData().forPath(zkPath)
      Some(Utils.deserialize[T](fileData))
    } catch {
      case e: NoNodeException => None
      case e: Exception =>
        logWarning("Exception while reading persisted file, deleting", e)
        zk.delete().forPath(zkPath)
        None
    }
  }

  override def fetchAll[T](): Iterable[T] = {
    zk.getChildren.forPath(WORKING_DIR).asScala.flatMap(fetch[T])
  }
} 
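The persistence engine above round-trips arbitrary objects through Utils.serialize and Utils.deserialize, which are thin wrappers around plain Java serialization. A minimal hedged sketch of that round trip, independent of ZooKeeper (the object name is illustrative, and as noted in the introduction the code must sit under the org.apache.spark package):

import org.apache.spark.util.Utils

object SerializeRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Any java.io.Serializable value works; immutable Maps of primitives are Serializable.
    val original = Map("driver" -> 1, "executor" -> 2)
    val bytes: Array[Byte] = Utils.serialize(original)
    val restored = Utils.deserialize[Map[String, Int]](bytes)
    assert(restored == original)
  }
}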
Example 10
Source File: BytecodeUtils.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.graphx.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.HashSet
import scala.language.existentials

import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor}
import org.apache.xbean.asm5.Opcodes._

import org.apache.spark.util.Utils


private[graphx] object BytecodeUtils {

  // (the invokedMethod entry points and the skipClass helper are omitted in this excerpt)

  private class MethodInvocationFinder(className: String, methodName: String)
    extends ClassVisitor(ASM5) {

    val methodsInvoked = new HashSet[(Class[_], String)]

    override def visitMethod(access: Int, name: String, desc: String,
                             sig: String, exceptions: Array[String]): MethodVisitor = {
      if (name == methodName) {
        new MethodVisitor(ASM5) {
          override def visitMethodInsn(
              op: Int, owner: String, name: String, desc: String, itf: Boolean) {
            if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
              if (!skipClass(owner)) {
                methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name))
              }
            }
          }
        }
      } else {
        null
      }
    }
  }
} 
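BytecodeUtils resolves each invocation target with Utils.classForName, Spark's standard way of loading a class by its fully qualified name on the context or Spark class loader. ASM reports owners in internal form ("scala/collection/Seq"), which is why the example converts the slashes first; a tiny hedged sketch (object name illustrative):

import org.apache.spark.util.Utils

object ClassForNameSketch {
  def main(args: Array[String]): Unit = {
    // ASM hands back internal names like "scala/collection/Seq"; convert before loading.
    val owner = "scala/collection/Seq"
    val cls: Class[_] = Utils.classForName(owner.replace("/", "."))
    println(cls.getName) // prints scala.collection.Seq
  }
}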
Example 11
Source File: GraphLoaderSuite.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.graphx

import java.io.File
import java.io.FileOutputStream
import java.io.OutputStreamWriter
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils

class GraphLoaderSuite extends SparkFunSuite with LocalSparkContext {

  test("GraphLoader.edgeListFile") {
    withSpark { sc =>
      val tmpDir = Utils.createTempDir()
      val graphFile = new File(tmpDir.getAbsolutePath, "graph.txt")
      val writer = new OutputStreamWriter(new FileOutputStream(graphFile), StandardCharsets.UTF_8)
      for (i <- (1 until 101)) writer.write(s"$i 0\n")
      writer.close()
      try {
        val graph = GraphLoader.edgeListFile(sc, tmpDir.getAbsolutePath)
        val neighborAttrSums = graph.aggregateMessages[Int](
          ctx => ctx.sendToDst(ctx.srcAttr),
          _ + _)
        assert(neighborAttrSums.collect.toSet === Set((0: VertexId, 100)))
      } finally {
        Utils.deleteRecursively(tmpDir)
      }
    }
  }
} 
Example 12
Source File: BaggedPoint.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.ml.tree.impl

import org.apache.commons.math3.distribution.PoissonDistribution

import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom


private[spark] object BaggedPoint {

  // (the BaggedPoint class itself, which pairs a datum with its subsample weights,
  //  is omitted in this excerpt)

  def convertToBaggedRDD[Datum] (
      input: RDD[Datum],
      subsamplingRate: Double,
      numSubsamples: Int,
      withReplacement: Boolean,
      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
    if (withReplacement) {
      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
    } else {
      if (numSubsamples == 1 && subsamplingRate == 1.0) {
        convertToBaggedRDDWithoutSampling(input)
      } else {
        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
      }
    }
  }

  private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
      input: RDD[Datum],
      subsamplingRate: Double,
      numSubsamples: Int,
      seed: Long): RDD[BaggedPoint[Datum]] = {
    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
      val rng = new XORShiftRandom
      rng.setSeed(seed + partitionIndex + 1)
      instances.map { instance =>
        val subsampleWeights = new Array[Double](numSubsamples)
        var subsampleIndex = 0
        while (subsampleIndex < numSubsamples) {
          val x = rng.nextDouble()
          subsampleWeights(subsampleIndex) = {
            if (x < subsamplingRate) 1.0 else 0.0
          }
          subsampleIndex += 1
        }
        new BaggedPoint(instance, subsampleWeights)
      }
    }
  }

  private def convertToBaggedRDDSamplingWithReplacement[Datum] (
      input: RDD[Datum],
      subsample: Double,
      numSubsamples: Int,
      seed: Long): RDD[BaggedPoint[Datum]] = {
    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
      val poisson = new PoissonDistribution(subsample)
      poisson.reseedRandomGenerator(seed + partitionIndex + 1)
      instances.map { instance =>
        val subsampleWeights = new Array[Double](numSubsamples)
        var subsampleIndex = 0
        while (subsampleIndex < numSubsamples) {
          subsampleWeights(subsampleIndex) = poisson.sample()
          subsampleIndex += 1
        }
        new BaggedPoint(instance, subsampleWeights)
      }
    }
  }

  private def convertToBaggedRDDWithoutSampling[Datum] (
      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
    input.map(datum => new BaggedPoint(datum, Array(1.0)))
  }

} 
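convertToBaggedRDD defaults its seed to Utils.random.nextLong(), a draw from Spark's shared random number generator, so repeated runs subsample differently; passing an explicit seed makes the bagging reproducible. A hedged sketch of the two call styles (the surrounding object and the RDD parameter are illustrative):

import org.apache.spark.ml.tree.impl.BaggedPoint
import org.apache.spark.rdd.RDD

object BaggedPointSketch {
  def bagTwice(input: RDD[Double]): Unit = {
    // Default seed: a fresh draw from Utils.random, so results differ from run to run.
    val bagged = BaggedPoint.convertToBaggedRDD(input, 0.7, 3, withReplacement = false)

    // Explicit seed: the per-partition generators are seeded deterministically.
    val baggedFixed = BaggedPoint.convertToBaggedRDD(input, 0.7, 3, withReplacement = false, seed = 42L)
  }
}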
Example 13
Source File: HashingTF.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.mllib.feature

import java.lang.{Iterable => JavaIterable}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.unsafe.hash.Murmur3_x86_32._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils


object HashingTF {

  // (the HashingTF class and the native hash path are omitted in this excerpt)
  private val seed = 42

  private[spark] def murmur3Hash(term: Any): Int = {
    term match {
      case null => seed
      case b: Boolean => hashInt(if (b) 1 else 0, seed)
      case b: Byte => hashInt(b, seed)
      case s: Short => hashInt(s, seed)
      case i: Int => hashInt(i, seed)
      case l: Long => hashLong(l, seed)
      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
      case s: String =>
        val utf8 = UTF8String.fromString(s)
        hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
      case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " +
        s"support type ${term.getClass.getCanonicalName} of input data.")
    }
  }
} 
Example 14
Source File: HashingTFSuite.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils

class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import testImplicits._

  test("params") {
    ParamsSuite.checkParams(new HashingTF)
  }

  test("hashingTF") {
    val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words")
    val n = 100
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")
      .setNumFeatures(n)
    val output = hashingTF.transform(df)
    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
    require(attrGroup.numAttributes === Some(n))
    val features = output.select("features").first().getAs[Vector](0)
    // Assume perfect hash on "a", "b", "c", and "d".
    def idx: Any => Int = murmur3FeatureIdx(n)
    val expected = Vectors.sparse(n,
      Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
    assert(features ~== expected absTol 1e-14)
  }

  test("applying binary term freqs") {
    val df = Seq((0, "a a b c c c".split(" ").toSeq)).toDF("id", "words")
    val n = 100
    val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("features")
        .setNumFeatures(n)
        .setBinary(true)
    val output = hashingTF.transform(df)
    val features = output.select("features").first().getAs[Vector](0)
    def idx: Any => Int = murmur3FeatureIdx(n)  // Assume perfect hash on input features
    val expected = Vectors.sparse(n,
      Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
    assert(features ~== expected absTol 1e-14)
  }

  test("read/write") {
    val t = new HashingTF()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setNumFeatures(10)
    testDefaultReadWrite(t)
  }

  private def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
    Utils.nonNegativeMod(MLlibHashingTF.murmur3Hash(term), numFeatures)
  }
} 
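The murmur3FeatureIdx helper maps a (possibly negative) murmur3 hash into a bucket with Utils.nonNegativeMod, which always returns a value in [0, numFeatures), unlike Scala's % operator, whose result keeps the sign of the dividend. A small hedged sketch (the object name and the sample hash value are illustrative):

import org.apache.spark.util.Utils

object NonNegativeModSketch {
  def main(args: Array[String]): Unit = {
    val numFeatures = 100
    val hash = -1713851907                           // murmur3 hashes are often negative
    println(hash % numFeatures)                      // -7: plain % keeps the sign
    println(Utils.nonNegativeMod(hash, numFeatures)) //  93: always a valid bucket index
  }
}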
Example 15
Source File: LibSVMRelationSuite.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.ml.source.libsvm

import java.io.File
import java.nio.charset.StandardCharsets

import com.google.common.io.Files

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.util.Utils


class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
  // Path for dataset
  var path: String = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    val lines =
      """
        |1 1:1.0 3:2.0 5:3.0
        |0
        |0 2:4.0 4:5.0 6:6.0
      """.stripMargin
    val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
    val file = new File(dir, "part-00000")
    Files.write(lines, file, StandardCharsets.UTF_8)
    path = dir.toURI.toString
  }

  override def afterAll(): Unit = {
    try {
      Utils.deleteRecursively(new File(path))
    } finally {
      super.afterAll()
    }
  }

  test("select as sparse vector") {
    val df = spark.read.format("libsvm").load(path)
    assert(df.columns(0) == "label")
    assert(df.columns(1) == "features")
    val row1 = df.first()
    assert(row1.getDouble(0) == 1.0)
    val v = row1.getAs[SparseVector](1)
    assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
  }

  test("select as dense vector") {
    val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense"))
      .load(path)
    assert(df.columns(0) == "label")
    assert(df.columns(1) == "features")
    assert(df.count() == 3)
    val row1 = df.first()
    assert(row1.getDouble(0) == 1.0)
    val v = row1.getAs[DenseVector](1)
    assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
  }

  test("select a vector with specifying the longer dimension") {
    val df = spark.read.option("numFeatures", "100").format("libsvm")
      .load(path)
    val row1 = df.first()
    val v = row1.getAs[SparseVector](1)
    assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
  }

  test("write libsvm data and read it again") {
    val df = spark.read.format("libsvm").load(path)
    val tempDir2 = new File(tempDir, "read_write_test")
    val writepath = tempDir2.toURI.toString
    // TODO: Remove requirement to coalesce by supporting multiple reads.
    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)

    val df2 = spark.read.format("libsvm").load(writepath)
    val row1 = df2.first()
    val v = row1.getAs[SparseVector](1)
    assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
  }

  test("write libsvm data failed due to invalid schema") {
    val df = spark.read.format("text").load(path)
    intercept[SparkException] {
      df.write.format("libsvm").save(path + "_2")
    }
  }

  test("select features from libsvm relation") {
    val df = spark.read.format("libsvm").load(path)
    df.select("features").rdd.map { case Row(d: Vector) => d }.first
    df.select("features").collect
  }
} 
Example 16
Source File: ChiSqSelectorSuite.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils

class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {

  

  test("ChiSqSelector transform test (sparse & dense vector)") {
    val labeledDiscreteData = sc.parallelize(
      Seq(LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
        LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
        LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
        LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
    val preFilteredData =
      Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
        LabeledPoint(1.0, Vectors.dense(Array(0.0))),
        LabeledPoint(1.0, Vectors.dense(Array(0.0))),
        LabeledPoint(2.0, Vectors.dense(Array(8.0))))
    val model = new ChiSqSelector(1).fit(labeledDiscreteData)
    val filteredData = labeledDiscreteData.map { lp =>
      LabeledPoint(lp.label, model.transform(lp.features))
    }.collect().toSet
    assert(filteredData == preFilteredData)
  }

  test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
    val labeledDiscreteData = sc.parallelize(
      Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
        LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
        LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
        LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
    val preFilteredData =
      Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
        LabeledPoint(1.0, Vectors.dense(Array(4.0))),
        LabeledPoint(1.0, Vectors.dense(Array(4.0))),
        LabeledPoint(2.0, Vectors.dense(Array(9.0))))
    val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
    val filteredData = labeledDiscreteData.map { lp =>
      LabeledPoint(lp.label, model.transform(lp.features))
    }.collect().toSet
    assert(filteredData == preFilteredData)
  }

  test("model load / save") {
    val model = ChiSqSelectorSuite.createModel()
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString
    try {
      model.save(sc, path)
      val sameModel = ChiSqSelectorModel.load(sc, path)
      ChiSqSelectorSuite.checkEqual(model, sameModel)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}

object ChiSqSelectorSuite extends SparkFunSuite {

  def createModel(): ChiSqSelectorModel = {
    val arr = Array(1, 2, 3, 4)
    new ChiSqSelectorModel(arr)
  }

  def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = {
    assert(a.selectedFeatures.deep == b.selectedFeatures.deep)
  }
} 
Example 17
Source File: MatrixFactorizationModelSuite.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.mllib.recommendation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext {

  val rank = 2
  var userFeatures: RDD[(Int, Array[Double])] = _
  var prodFeatures: RDD[(Int, Array[Double])] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
    prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
  }

  test("constructor") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)

    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
    }

    val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures1, prodFeatures)
    }

    val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
    }
  }

  test("save/load") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString
    def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
      features.mapValues(_.toSeq).collect().toSet
    }
    try {
      model.save(sc, path)
      val newModel = MatrixFactorizationModel.load(sc, path)
      assert(newModel.rank === rank)
      assert(collect(newModel.userFeatures) === collect(userFeatures))
      assert(collect(newModel.productFeatures) === collect(prodFeatures))
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }

  test("batch predict API recommendProductsForUsers") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val topK = 10
    val recommendations = model.recommendProductsForUsers(topK).collectAsMap()

    assert(recommendations(0)(0).rating ~== 17.0 relTol 1e-14)
    assert(recommendations(1)(0).rating ~== 39.0 relTol 1e-14)
  }

  test("batch predict API recommendUsersForProducts") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val topK = 10
    val recommendations = model.recommendUsersForProducts(topK).collectAsMap()

    assert(recommendations(2)(0).user == 1)
    assert(recommendations(2)(0).rating ~== 39.0 relTol 1e-14)
    assert(recommendations(2)(1).user == 0)
    assert(recommendations(2)(1).rating ~== 17.0 relTol 1e-14)
  }
} 
Example 18
Source File: MLlibTestSparkContext.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.mllib.util

import java.io.File

import org.scalatest.Suite

import org.apache.spark.SparkContext
import org.apache.spark.ml.util.TempDirectory
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.util.Utils

trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
  @transient var spark: SparkSession = _
  @transient var sc: SparkContext = _
  @transient var checkpointDir: String = _

  override def beforeAll() {
    super.beforeAll()
    spark = SparkSession.builder
      .master("local[2]")
      .appName("MLlibUnitTest")
      .getOrCreate()
    sc = spark.sparkContext

    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
    sc.setCheckpointDir(checkpointDir)
  }

  override def afterAll() {
    try {
      Utils.deleteRecursively(new File(checkpointDir))
      SparkSession.clearActiveSession()
      if (spark != null) {
        spark.stop()
      }
      spark = null
    } finally {
      super.afterAll()
    }
  }

  
  protected object testImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = self.spark.sqlContext
  }
} 
Example 19
Source File: RidgeRegressionSuite.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.mllib.regression

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext,
  MLlibTestSparkContext}
import org.apache.spark.util.Utils

private object RidgeRegressionSuite {

  
  val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}

class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

  def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
    predictions.zip(input).map { case (prediction, expected) =>
      (prediction - expected.label) * (prediction - expected.label)
    }.sum / predictions.size
  }

  test("ridge regression can help avoid overfitting") {

    // For small number of examples and large variance of error distribution,
    // ridge regression should give smaller generalization error than linear regression.

    val numExamples = 50
    val numFeatures = 20

    // Pick weights as random values distributed uniformly in [-0.5, 0.5]
    val random = new Random(42)
    val w = Array.fill(numFeatures)(random.nextDouble() - 0.5)

    // Use half of data for training and other half for validation
    val data = LinearDataGenerator.generateLinearInput(3.0, w, 2 * numExamples, 42, 10.0)
    val testData = data.take(numExamples)
    val validationData = data.takeRight(numExamples)

    val testRDD = sc.parallelize(testData, 2).cache()
    val validationRDD = sc.parallelize(validationData, 2).cache()

    // First run without regularization.
    val linearReg = new LinearRegressionWithSGD()
    linearReg.optimizer.setNumIterations(200)
                       .setStepSize(1.0)

    val linearModel = linearReg.run(testRDD)
    val linearErr = predictionError(
        linearModel.predict(validationRDD.map(_.features)).collect(), validationData)

    val ridgeReg = new RidgeRegressionWithSGD()
    ridgeReg.optimizer.setNumIterations(200)
                      .setRegParam(0.1)
                      .setStepSize(1.0)
    val ridgeModel = ridgeReg.run(testRDD)
    val ridgeErr = predictionError(
        ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)

    // Ridge validation error should be lower than linear regression.
    assert(ridgeErr < linearErr,
      "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
  }

  test("model save/load") {
    val model = RidgeRegressionSuite.model

    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString

    // Save model, load it back, and compare.
    try {
      model.save(sc, path)
      val sameModel = RidgeRegressionModel.load(sc, path)
      assert(model.weights == sameModel.weights)
      assert(model.intercept == sameModel.intercept)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}

class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {

  test("task size should be small in both training and prediction") {
    val m = 4
    val n = 200000
    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
      val random = new Random(idx)
      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
    }.cache()
    // If we serialize data directly in the task closure, the size of the serialized task would be
    // greater than 1MB and hence Spark would throw an error.
    val model = RidgeRegressionWithSGD.train(points, 2)
    val predictions = model.predict(points.map(_.features))
  }
} 
Example 20
Source File: FlumeInputDStream.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.streaming.flume

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.Executors

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.avro.ipc.NettyServer
import org.apache.avro.ipc.specific.SpecificResponder
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol, Status}
import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels}
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
import org.jboss.netty.handler.codec.compression._

import org.apache.spark.internal.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.Utils

private[streaming]
class FlumeInputDStream[T: ClassTag](
  _ssc: StreamingContext,
  host: String,
  port: Int,
  storageLevel: StorageLevel,
  enableDecompression: Boolean
) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) {

  override def getReceiver(): Receiver[SparkFlumeEvent] = {
    new FlumeReceiver(host, port, storageLevel, enableDecompression)
  }
}


private[streaming]
class FlumeReceiver(
    host: String,
    port: Int,
    storageLevel: StorageLevel,
    enableDecompression: Boolean
  ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging {

  // (the Avro/Netty server lifecycle of the receiver is omitted in this excerpt)

  private[streaming]
  class CompressionChannelPipelineFactory extends ChannelPipelineFactory {
    def getPipeline(): ChannelPipeline = {
      val pipeline = Channels.pipeline()
      val encoder = new ZlibEncoder(6)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      pipeline
    }
  }
} 
Example 21
Source File: FlumeTestUtils.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.streaming.flume

import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.{List => JList}
import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.commons.lang3.RandomUtils
import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol}
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder}

import org.apache.spark.util.Utils
import org.apache.spark.SparkConf


private[flume] class FlumeTestUtils {

  // (the helpers that bind a test Avro source and push events to it are omitted in this excerpt)

  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }

} 
Example 22
Source File: EventTransformer.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.streaming.flume

import java.io.{ObjectInput, ObjectOutput}

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[streaming] object EventTransformer extends Logging {
  def readExternal(in: ObjectInput): (java.util.HashMap[CharSequence, CharSequence],
    Array[Byte]) = {
    val bodyLength = in.readInt()
    val bodyBuff = new Array[Byte](bodyLength)
    in.readFully(bodyBuff)

    val numHeaders = in.readInt()
    val headers = new java.util.HashMap[CharSequence, CharSequence]

    for (i <- 0 until numHeaders) {
      val keyLength = in.readInt()
      val keyBuff = new Array[Byte](keyLength)
      in.readFully(keyBuff)
      val key: String = Utils.deserialize(keyBuff)

      val valLength = in.readInt()
      val valBuff = new Array[Byte](valLength)
      in.readFully(valBuff)
      val value: String = Utils.deserialize(valBuff)

      headers.put(key, value)
    }
    (headers, bodyBuff)
  }

  def writeExternal(out: ObjectOutput, headers: java.util.Map[CharSequence, CharSequence],
    body: Array[Byte]) {
    out.writeInt(body.length)
    out.write(body)
    val numHeaders = headers.size()
    out.writeInt(numHeaders)
    for ((k, v) <- headers.asScala) {
      val keyBuff = Utils.serialize(k.toString)
      out.writeInt(keyBuff.length)
      out.write(keyBuff)
      val valBuff = Utils.serialize(v.toString)
      out.writeInt(valBuff.length)
      out.write(valBuff)
    }
  }
} 
Example 23
Source File: TestOutputStream.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.streaming

import java.io.{IOException, ObjectInputStream}
import java.util.concurrent.ConcurrentLinkedQueue

import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.{DStream, ForEachDStream}
import org.apache.spark.util.Utils


class TestOutputStream[T: ClassTag](parent: DStream[T],
    val output: ConcurrentLinkedQueue[Seq[T]] = new ConcurrentLinkedQueue[Seq[T]]())
  extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
    val collected = rdd.collect()
    output.add(collected)
  }, false) {

  // This is to clear the output buffer every time it is read from a checkpoint
  @throws(classOf[IOException])
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    ois.defaultReadObject()
    output.clear()
  }
} 
Example 24
Source File: SparkSQLEnv.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.io.PrintStream

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils}
import org.apache.spark.util.Utils


private[hive] object SparkSQLEnv extends Logging {

  var sqlContext: SQLContext = null
  var sparkContext: SparkContext = null

  // (init(), which creates the SparkContext and SQLContext, is omitted in this excerpt)

  def stop() {
    logDebug("Shutting down Spark SQL Environment")
    // Stop the SparkContext
    if (SparkSQLEnv.sparkContext != null) {
      sparkContext.stop()
      sparkContext = null
      sqlContext = null
    }
  }
} 
Example 25
Source File: JdbcConnectionUriSuite.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.sql.DriverManager

import org.apache.hive.jdbc.HiveDriver

import org.apache.spark.util.Utils

class JdbcConnectionUriSuite extends HiveThriftServer2Test {
  Utils.classForName(classOf[HiveDriver].getCanonicalName)

  override def mode: ServerMode.Value = ServerMode.binary

  val JDBC_TEST_DATABASE = "jdbc_test_database"
  val USER = System.getProperty("user.name")
  val PASSWORD = ""

  override protected def beforeAll(): Unit = {
    super.beforeAll()

    val jdbcUri = s"jdbc:hive2://localhost:$serverPort/"
    val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD)
    val statement = connection.createStatement()
    statement.execute(s"CREATE DATABASE $JDBC_TEST_DATABASE")
    connection.close()
  }

  override protected def afterAll(): Unit = {
    try {
      val jdbcUri = s"jdbc:hive2://localhost:$serverPort/"
      val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD)
      val statement = connection.createStatement()
      statement.execute(s"DROP DATABASE $JDBC_TEST_DATABASE")
      connection.close()
    } finally {
      super.afterAll()
    }
  }

  test("SPARK-17819 Support default database in connection URIs") {
    val jdbcUri = s"jdbc:hive2://localhost:$serverPort/$JDBC_TEST_DATABASE"
    val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD)
    val statement = connection.createStatement()
    try {
      val resultSet = statement.executeQuery("select current_database()")
      resultSet.next()
      assert(resultSet.getString(1) === JDBC_TEST_DATABASE)
    } finally {
      statement.close()
      connection.close()
    }
  }
} 
Example 26
Source File: UDTRegistration.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.types

import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] object UDTRegistration extends Serializable with Logging {

  // (the udtMap registry and the register/exists helpers are omitted in this excerpt)

  def getUDTFor(userClass: String): Option[Class[_]] = {
    udtMap.get(userClass).map { udtClassName =>
      if (Utils.classIsLoadable(udtClassName)) {
        val udtClass = Utils.classForName(udtClassName)
        if (classOf[UserDefinedType[_]].isAssignableFrom(udtClass)) {
          udtClass
        } else {
          throw new SparkException(
            s"${udtClass.getName} is not an UserDefinedType. Please make sure registering " +
              s"an UserDefinedType for ${userClass}")
        }
      } else {
        throw new SparkException(
          s"Can not load in UserDefinedType ${udtClassName} for user class ${userClass}.")
      }
    }
  }
} 
Example 27
Source File: randomExpressions.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom


@ExpressionDescription(
  usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.")
case class Randn(seed: Long) extends RDG {
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  def this() = this(Utils.random.nextLong())

  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
  })

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val rngTerm = ctx.freshName("rng")
    val className = classOf[XORShiftRandom].getName
    ctx.addMutableState(className, rngTerm,
      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
  }
} 
Example 28
Source File: package.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.rules
import org.apache.spark.util.Utils


package object expressions {

  // (other members of the expressions package object are omitted in this excerpt)

  object DumpByteCode {
    import scala.sys.process._
    val dumpDirectory = Utils.createTempDir()
    dumpDirectory.mkdir()

    def apply(obj: Any): Unit = {
      val generatedClass = obj.getClass
      val classLoader =
        generatedClass
          .getClassLoader
          .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader]
      val generatedBytes = classLoader.classBytes(generatedClass.getName)

      val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName)
      if (!packageDir.exists()) { packageDir.mkdir() }

      val classFile =
        new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class")

      val outfile = new java.io.FileOutputStream(classFile)
      outfile.write(generatedBytes)
      outfile.close()

      // scalastyle:off println
      println(
        s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!)
      // scalastyle:on println
    }
  }
} 
Example 29
Source File: OuterScopes.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.encoders

import java.util.concurrent.ConcurrentMap

import com.google.common.collect.MapMaker

import org.apache.spark.util.Utils

object OuterScopes {
  @transient
  lazy val outerScopes: ConcurrentMap[String, AnyRef] =
    new MapMaker().weakValues().makeMap()

  
  def getOuterScope(innerCls: Class[_]): () => AnyRef = {
    assert(innerCls.isMemberClass)
    val outerClassName = innerCls.getDeclaringClass.getName
    val outer = outerScopes.get(outerClassName)
    if (outer == null) {
      outerClassName match {
        // If the outer class is generated by REPL, users don't need to register it as it has
        // only one instance and there is a way to retrieve it: get the `$read` object, call the
        // `INSTANCE()` method to get the single instance of class `$read`. Then call the `$iw()`
        // method multiple times to get the single instance of the innermost `$iw` class.
        case REPLClass(baseClassName) =>
          () => {
            val objClass = Utils.classForName(baseClassName + "$")
            val objInstance = objClass.getField("MODULE$").get(null)
            val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
            val baseClass = Utils.classForName(baseClassName)

            var getter = iwGetter(baseClass)
            var obj = baseInstance
            while (getter != null) {
              obj = getter.invoke(obj)
              getter = iwGetter(getter.getReturnType)
            }

            if (obj == null) {
              throw new RuntimeException(s"Failed to get outer pointer for ${innerCls.getName}")
            }

            outerScopes.putIfAbsent(outerClassName, obj)
            obj
          }
        case _ => null
      }
    } else {
      () => outer
    }
  }

  private def iwGetter(cls: Class[_]) = {
    try {
      cls.getMethod("$iw")
    } catch {
      case _: NoSuchMethodException => null
    }
  }

  // The format of REPL generated wrapper class's name, e.g. `$line12.$read$$iw$$iw`
  private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r
} 
Example 30
Source File: RuleExecutor.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.rules

import scala.collection.JavaConverters._

import com.google.common.util.concurrent.AtomicLongMap

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.util.Utils

object RuleExecutor {
  protected val timeMap = AtomicLongMap.create[String]()

  
  // (resetTime() and dumpTimeSpent() are omitted in this excerpt)
}


abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {

  // (the Batch/Strategy definitions and the protected `batches` member are omitted in this excerpt)

  def execute(plan: TreeType): TreeType = {
    var curPlan = plan

    batches.foreach { batch =>
      val batchStartPlan = curPlan
      var iteration = 1
      var lastPlan = curPlan
      var continue = true

      // Run until fixed point (or the max number of iterations specified in the strategy).
      while (continue) {
        curPlan = batch.rules.foldLeft(curPlan) {
          case (plan, rule) =>
            val startTime = System.nanoTime()
            val result = rule(plan)
            val runTime = System.nanoTime() - startTime
            RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime)

            if (!result.fastEquals(plan)) {
              logTrace(
                s"""
                  |=== Applying Rule ${rule.ruleName} ===
                  |${sideBySide(plan.treeString, result.treeString).mkString("\n")}
                """.stripMargin)
            }

            result
        }
        iteration += 1
        if (iteration > batch.strategy.maxIterations) {
          // Only log if this is a rule that is supposed to run more than once.
          if (iteration != 2) {
            val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}"
            if (Utils.isTesting) {
              throw new TreeNodeException(curPlan, message, null)
            } else {
              logWarning(message)
            }
          }
          continue = false
        }

        if (curPlan.fastEquals(lastPlan)) {
          logTrace(
            s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.")
          continue = false
        }
        lastPlan = curPlan
      }

      if (!batchStartPlan.fastEquals(curPlan)) {
        logDebug(
          s"""
          |=== Result of Batch ${batch.name} ===
          |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")}
        """.stripMargin)
      } else {
        logTrace(s"Batch ${batch.name} has no effect.")
      }
    }

    curPlan
  }
} 
Example 31
Source File: CompressionCodecs.scala    From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress._

import org.apache.spark.util.Utils

object CompressionCodecs {
  private val shortCompressionCodecNames = Map(
    "none" -> null,
    "uncompressed" -> null,
    "bzip2" -> classOf[BZip2Codec].getName,
    "deflate" -> classOf[DeflateCodec].getName,
    "gzip" -> classOf[GzipCodec].getName,
    "lz4" -> classOf[Lz4Codec].getName,
    "snappy" -> classOf[SnappyCodec].getName)

  
  def setCodecConfiguration(conf: Configuration, codec: String): Unit = {
    if (codec != null) {
      conf.set("mapreduce.output.fileoutputformat.compress", "true")
      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
      conf.set("mapreduce.map.output.compress", "true")
      conf.set("mapreduce.map.output.compress.codec", codec)
    } else {
      // This means the option `compression` is set to `uncompressed` or `none`.
      conf.set("mapreduce.output.fileoutputformat.compress", "false")
      conf.set("mapreduce.map.output.compress", "false")
    }
  }
} 
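
The `Utils` import above is used by the codec-name validation that this excerpt omits. Here is a rough, hypothetical sketch of that kind of check; `CodecNameCheck` and `resolve` are illustrative names, not Spark API.

package org.apache.spark.sql.catalyst.util

import org.apache.spark.util.Utils

// Hypothetical helper: resolve a user-facing short codec name and verify the class exists.
object CodecNameCheck {
  private val shortNames = Map(
    "gzip" -> "org.apache.hadoop.io.compress.GzipCodec",
    "snappy" -> "org.apache.hadoop.io.compress.SnappyCodec")

  def resolve(name: String): String = {
    val codecName = shortNames.getOrElse(name.toLowerCase, name)
    // Utils.classForName throws ClassNotFoundException if the codec is not on the classpath.
    Utils.classForName(codecName)
    codecName
  }
}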
Example 32
Source File: LogicalRelation.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.util.Utils


  override def newInstance(): this.type = {
    LogicalRelation(
      relation,
      expectedOutputAttributes.map(_.map(_.newInstance())),
      catalogTable).asInstanceOf[this.type]
  }

  override def refresh(): Unit = relation match {
    case fs: HadoopFsRelation => fs.location.refresh()
    case _ =>  // Do nothing.
  }

  override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
} 
Example 33
Source File: DriverRegistry.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Driver, DriverManager}

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


object DriverRegistry extends Logging {

  private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty

  def register(className: String): Unit = {
    val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
    if (cls.getClassLoader == null) {
      logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required")
    } else if (wrapperMap.get(className).isDefined) {
      logTrace(s"Wrapper for $className already exists")
    } else {
      synchronized {
        if (wrapperMap.get(className).isEmpty) {
          val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
          DriverManager.registerDriver(wrapper)
          wrapperMap(className) = wrapper
          logTrace(s"Wrapper for $className registered")
        }
      }
    }
  }
} 
Example 34
Source File: SparkPlanInfo.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.SQLMetricInfo
import org.apache.spark.util.Utils


@DeveloperApi
class SparkPlanInfo(
    val nodeName: String,
    val simpleString: String,
    val children: Seq[SparkPlanInfo],
    val metadata: Map[String, String],
    val metrics: Seq[SQLMetricInfo]) {

  override def hashCode(): Int = {
    // hashCode of simpleString should be good enough to distinguish the plans from each other
    // within a plan
    simpleString.hashCode
  }

  override def equals(other: Any): Boolean = other match {
    case o: SparkPlanInfo =>
      nodeName == o.nodeName && simpleString == o.simpleString && children == o.children
    case _ => false
  }
}

private[execution] object SparkPlanInfo {

  def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
    val children = plan match {
      case ReusedExchangeExec(_, child) => child :: Nil
      case _ => plan.children ++ plan.subqueries
    }
    val metrics = plan.metrics.toSeq.map { case (key, metric) =>
      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
    }

    new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan),
      plan.metadata, metrics)
  }
} 
Example 35
Source File: SQLMetrics.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.metric

import java.text.NumberFormat
import java.util.Locale

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}


class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
  // This is a workaround for SPARK-11013.
  // We may use -1 as the initial value of the accumulator. If the accumulator is valid, we
  // update it at the end of the task so the value will be at least 0. We can then filter out
  // the -1 values before calculating max, min, etc.
  private[this] var _value = initValue
  private var _zeroValue = initValue

  override def copy(): SQLMetric = {
    val newAcc = new SQLMetric(metricType, _value)
    newAcc._zeroValue = initValue
    newAcc
  }

  override def reset(): Unit = _value = _zeroValue

  override def merge(other: AccumulatorV2[Long, Long]): Unit = other match {
    case o: SQLMetric => _value += o.value
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def isZero(): Boolean = _value == _zeroValue

  override def add(v: Long): Unit = _value += v

  def +=(v: Long): Unit = _value += v

  override def value: Long = _value

  // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
  override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
    new AccumulableInfo(
      id, name, update, value, true, true, Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
  }
}


object SQLMetrics {
  private val SUM_METRIC = "sum"
  private val SIZE_METRIC = "size"
  private val TIMING_METRIC = "timing"

  def createMetric(sc: SparkContext, name: String): SQLMetric = {
    val acc = new SQLMetric(SUM_METRIC)
    acc.register(sc, name = Some(name), countFailedValues = false)
    acc
  }

  
  def stringValue(metricsType: String, values: Seq[Long]): String = {
    if (metricsType == SUM_METRIC) {
      val numberFormat = NumberFormat.getIntegerInstance(Locale.ENGLISH)
      numberFormat.format(values.sum)
    } else {
      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
        Utils.bytesToString
      } else if (metricsType == TIMING_METRIC) {
        Utils.msDurationToString
      } else {
        throw new IllegalStateException("unexpected metrics type: " + metricsType)
      }

      val validValues = values.filter(_ >= 0)
      val Seq(sum, min, med, max) = {
        val metric = if (validValues.isEmpty) {
          Seq.fill(4)(0L)
        } else {
          val sorted = validValues.sorted
          Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
        }
        metric.map(strFormat)
      }
      s"\n$sum ($min, $med, $max)"
    }
  }
} 
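
`stringValue` above formats size metrics with `Utils.bytesToString` and timing metrics with `Utils.msDurationToString`. A small standalone sketch of the two formatters; `MetricFormatDemo` is illustrative only and the printed strings are approximate.

package org.apache.spark.sql.execution.metric

import org.apache.spark.util.Utils

// Standalone demonstration (not part of Spark) of the two formatters used by stringValue.
object MetricFormatDemo {
  def main(args: Array[String]): Unit = {
    println(Utils.bytesToString(10L * 1024 * 1024)) // size metrics, e.g. "10.0 MB"
    println(Utils.msDurationToString(65000L))       // timing metrics, e.g. "1.1 m"
  }
}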
Example 36
Source File: ExistingRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Encoder, Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

object RDDConversions {
  def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
    data.mapPartitions { iterator =>
      val numColumns = outputTypes.length
      val mutableRow = new GenericInternalRow(numColumns)
      val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
      iterator.map { r =>
        var i = 0
        while (i < numColumns) {
          mutableRow(i) = converters(i)(r.productElement(i))
          i += 1
        }

        mutableRow
      }
    }
  }
}

case class RDDScanExec(
    output: Seq[Attribute],
    rdd: RDD[InternalRow],
    override val nodeName: String) extends LeafExecNode {

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    rdd.mapPartitionsInternal { iter =>
      val proj = UnsafeProjection.create(schema)
      iter.map { r =>
        numOutputRows += 1
        proj(r)
      }
    }
  }

  override def simpleString: String = {
    s"Scan $nodeName${Utils.truncatedString(output, "[", ",", "]")}"
  }
} 
Example 37
Source File: RowDataSourceStrategySuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE TEMPORARY TABLE inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 38
Source File: PartitionedWriteSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("write many partitions") {
    val path = Utils.createTempDir()
    path.delete()

    val df = spark.range(100).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("write many partitions with repeats") {
    val path = Utils.createTempDir()
    path.delete()

    val base = spark.range(100)
    val df = base.union(base).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("partitioned columns should appear at the end of schema") {
    withTempPath { f =>
      val path = f.getAbsolutePath
      Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
      assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
    }
  }
} 
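
The suite above uses the common `Utils.createTempDir`/`Utils.deleteRecursively` pairing. A hypothetical helper capturing that pattern as a loan function; `TempDirExample` and `withTempDir` are illustrative names, not Spark API.

package org.apache.spark.sql.sources

import java.io.File

import org.apache.spark.util.Utils

// Hypothetical test helper following the createTempDir/deleteRecursively pattern above.
object TempDirExample {
  def withTempDir[T](f: File => T): T = {
    val dir = Utils.createTempDir()
    try f(dir) finally Utils.deleteRecursively(dir)
  }
}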
Example 39
Source File: YarnRMClient.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import java.util.{List => JList}

import scala.collection.JavaConverters._
import scala.util.Try

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.webapp.util.WebAppUtils

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class YarnRMClient extends Logging {

  def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
    val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt)
    val yarnMaxAttempts = yarnConf.getInt(
      YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
    val retval: Int = sparkMaxAttempts match {
      case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
      case None => yarnMaxAttempts
    }

    retval
  }

} 
Example 40
Source File: YarnClusterSchedulerBackend.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.conf.YarnConfiguration

import org.apache.spark.SparkContext
import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils

private[spark] class YarnClusterSchedulerBackend(
    scheduler: TaskSchedulerImpl,
    sc: SparkContext)
  extends YarnSchedulerBackend(scheduler, sc) {

  override def start() {
    val attemptId = ApplicationMaster.getAttemptId
    bindToYarn(attemptId.getApplicationId(), Some(attemptId))
    super.start()
    totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf)
  }

  override def getDriverLogUrls: Option[Map[String, String]] = {
    var driverLogs: Option[Map[String, String]] = None
    try {
      val yarnConf = new YarnConfiguration(sc.hadoopConfiguration)
      val containerId = YarnSparkHadoopUtil.get.getContainerId

      val httpAddress = System.getenv(Environment.NM_HOST.name()) +
        ":" + System.getenv(Environment.NM_HTTP_PORT.name())
      // lookup appropriate http scheme for container log urls
      val yarnHttpPolicy = yarnConf.get(
        YarnConfiguration.YARN_HTTP_POLICY_KEY,
        YarnConfiguration.YARN_HTTP_POLICY_DEFAULT
      )
      val user = Utils.getCurrentUserName()
      val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
      val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user"
      logDebug(s"Base URL for logs: $baseUrl")
      driverLogs = Some(Map(
        "stderr" -> s"$baseUrl/stderr?start=-4096",
        "stdout" -> s"$baseUrl/stdout?start=-4096"))
    } catch {
      case e: Exception =>
        logInfo("Error while building AM log links, so AM" +
          " logs link will not appear in application UI", e)
    }
    driverLogs
  }
} 
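
The log URLs above embed the result of `Utils.getCurrentUserName`, which prefers the `SPARK_USER` environment variable and otherwise asks Hadoop's `UserGroupInformation`. A minimal sketch; the `CurrentUserDemo` object is illustrative only.

package org.apache.spark.scheduler.cluster

import org.apache.spark.util.Utils

// Sketch (not part of Spark): the user name that ends up in the container log URLs above.
object CurrentUserDemo {
  def main(args: Array[String]): Unit = {
    println(Utils.getCurrentUserName())
  }
}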
Example 41
Source File: YarnScheduler.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.apache.hadoop.yarn.util.RackResolver
import org.apache.log4j.{Level, Logger}

import org.apache.spark._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils

private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {

  // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
  if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
    Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
  }

  // By default, rack is unknown
  override def getRackForHost(hostPort: String): Option[String] = {
    val host = Utils.parseHostPort(hostPort)._1
    Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation)
  }
} 
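
`getRackForHost` above only needs the host part of a `host:port` string, which `Utils.parseHostPort` extracts. A minimal sketch; `ParseHostPortDemo` is an illustrative name.

package org.apache.spark.scheduler.cluster

import org.apache.spark.util.Utils

// Sketch (not part of Spark) of Utils.parseHostPort, used above to strip the port
// before asking the RackResolver for a network location.
object ParseHostPortDemo {
  def main(args: Array[String]): Unit = {
    val (host, port) = Utils.parseHostPort("worker-1.example.com:7337")
    println(s"host=$host, port=$port") // host=worker-1.example.com, port=7337
  }
}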
Example 42
Source File: SchedulerExtensionService.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import java.util.concurrent.atomic.AtomicBoolean

import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}

import org.apache.spark.SparkContext
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


  override def stop(): Unit = {
    if (started.getAndSet(false)) {
      logInfo(s"Stopping $this")
      services.foreach { s =>
        Utils.tryLogNonFatalError(s.stop())
      }
    }
  }

  override def toString(): String = s"""SchedulerExtensionServices
    |(serviceOption=$serviceOption,
    | services=$services,
    | started=$started)""".stripMargin
} 
Example 43
Source File: DStreamCheckpointData.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.dstream

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable.HashMap
import scala.reflect.ClassTag

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.internal.Logging
import org.apache.spark.streaming.Time
import org.apache.spark.util.Utils

private[streaming]
class DStreamCheckpointData[T: ClassTag](dstream: DStream[T])
  extends Serializable with Logging {
  protected val data = new HashMap[Time, AnyRef]()

  // Mapping of the batch time to the checkpointed RDD file of that time
  @transient private var timeToCheckpointFile = new HashMap[Time, String]
  // Mapping of the batch time to the time of the oldest checkpointed RDD
  // in that batch's checkpoint data
  @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time]

  @transient private var fileSystem: FileSystem = null
  protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]]

  
  def restore() {
    // Create RDDs from the checkpoint data
    currentCheckpointFiles.foreach {
      case(time, file) =>
        logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'")
        dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file)))
    }
  }

  override def toString: String = {
    "[\n" + currentCheckpointFiles.size + " checkpoint files \n" +
      currentCheckpointFiles.mkString("\n") + "\n]"
  }

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    logDebug(this.getClass().getSimpleName + ".writeObject used")
    if (dstream.context.graph != null) {
      dstream.context.graph.synchronized {
        if (dstream.context.graph.checkpointInProgress) {
          oos.defaultWriteObject()
        } else {
          val msg = "Object of " + this.getClass.getName + " is being serialized " +
            " possibly as a part of closure of an RDD operation. This is because " +
            " the DStream object is being referred to from within the closure. " +
            " Please rewrite the RDD operation inside this DStream to avoid this. " +
            " This has been enforced to avoid bloating of Spark tasks " +
            " with unnecessary objects."
          throw new java.io.NotSerializableException(msg)
        }
      }
    } else {
      throw new java.io.NotSerializableException(
        "Graph is unexpectedly null when DStream is being serialized.")
    }
  }

  @throws(classOf[IOException])
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    logDebug(this.getClass().getSimpleName + ".readObject used")
    ois.defaultReadObject()
    timeToOldestCheckpointFileTime = new HashMap[Time, Time]
    timeToCheckpointFile = new HashMap[Time, String]
  }
} 
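
The `writeObject`/`readObject` pair above is wrapped in `Utils.tryOrIOException` so that any failure surfaces as an `IOException`, which Java serialization knows how to propagate. A hypothetical class using the same wrapper; `TransientCache` is not part of Spark.

package org.apache.spark.streaming.dstream

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import org.apache.spark.util.Utils

// Hypothetical Serializable class (not in Spark) using the same Utils.tryOrIOException wrapper.
class TransientCache(val name: String) extends Serializable {
  @transient private var cache: Map[String, String] = Map.empty

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    oos.defaultWriteObject() // the transient cache is intentionally not written
  }

  @throws(classOf[IOException])
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    ois.defaultReadObject()
    cache = Map.empty // rebuild transient state after deserialization
  }
}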
Example 44
Source File: Job.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import scala.util.{Failure, Try}

import org.apache.spark.streaming.Time
import org.apache.spark.util.{CallSite, Utils}


  def outputOpId: Int = {
    if (!isSet) {
      throw new IllegalStateException("Cannot access number before calling setId")
    }
    _outputOpId
  }

  def setOutputOpId(outputOpId: Int) {
    if (isSet) {
      throw new IllegalStateException("Cannot call setOutputOpId more than once")
    }
    isSet = true
    _id = s"streaming job $time.$outputOpId"
    _outputOpId = outputOpId
  }

  def setCallSite(callSite: CallSite): Unit = {
    _callSite = callSite
  }

  def callSite: CallSite = _callSite

  def setStartTime(startTime: Long): Unit = {
    _startTime = Some(startTime)
  }

  def setEndTime(endTime: Long): Unit = {
    _endTime = Some(endTime)
  }

  def toOutputOperationInfo: OutputOperationInfo = {
    val failureReason = if (_result != null && _result.isFailure) {
      Some(Utils.exceptionString(_result.asInstanceOf[Failure[_]].exception))
    } else {
      None
    }
    OutputOperationInfo(
      time, outputOpId, callSite.shortForm, callSite.longForm, _startTime, _endTime, failureReason)
  }

  override def toString: String = id
} 
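
`toOutputOperationInfo` above turns a failure into text with `Utils.exceptionString`, which includes the stack trace. A minimal sketch; `ExceptionStringDemo` is illustrative only.

package org.apache.spark.streaming.scheduler

import org.apache.spark.util.Utils

// Sketch (not in Spark): Utils.exceptionString renders a Throwable plus its stack trace
// as one String, which is how failureReason is built above.
object ExceptionStringDemo {
  def main(args: Array[String]): Unit = {
    val reason = Utils.exceptionString(new IllegalStateException("boom"))
    println(reason.split("\n").head) // expected: "java.lang.IllegalStateException: boom"
  }
}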
Example 45
Source File: FileBasedWriteAheadLogWriter.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.util

import java.io._
import java.nio.ByteBuffer

import org.apache.hadoop.conf.Configuration

import org.apache.spark.util.Utils


  def write(data: ByteBuffer): FileBasedWriteAheadLogSegment = synchronized {
    assertOpen()
    data.rewind() // Rewind to ensure all data in the buffer is retrieved
    val lengthToWrite = data.remaining()
    val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite)
    stream.writeInt(lengthToWrite)
    Utils.writeByteBuffer(data, stream: OutputStream)
    flush()
    nextOffset = stream.getPos()
    segment
  }

  override def close(): Unit = synchronized {
    closed = true
    stream.close()
  }

  private def flush() {
    stream.hflush()
    // Useful for local file system where hflush/sync does not work (HADOOP-7844)
    stream.getWrappedStream.flush()
  }

  private def assertOpen() {
    HdfsUtils.checkState(!closed, "Stream is closed. Create a new Writer to write to file.")
  }
} 
Example 46
Source File: FailureSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  private val batchDuration: Duration = Milliseconds(1000)
  private val numBatches = 30
  private var directory: File = null

  before {
    directory = Utils.createTempDir()
  }

  after {
    if (directory != null) {
      Utils.deleteRecursively(directory)
    }
    StreamingContext.getActive().foreach { _.stop() }

    // Stop SparkContext if active
    SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop()
  }

  test("multiple failures with map") {
    MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration)
  }

  test("multiple failures with updateStateByKey") {
    MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration)
  }
} 
Example 47
Source File: SerializableWritable.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.io._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.Writable

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

@DeveloperApi
class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {

  def value: T = t

  override def toString: String = t.toString

  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    out.defaultWriteObject()
    new ObjectWritable(t).write(out)
  }

  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    val ow = new ObjectWritable()
    ow.setConf(new Configuration(false))
    ow.readFields(in)
    t = ow.get().asInstanceOf[T]
  }
} 
Example 48
Source File: ShellBasedGroupsMappingProvider.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.security

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils



private[spark] class ShellBasedGroupsMappingProvider extends GroupMappingServiceProvider
  with Logging {

  override def getGroups(username: String): Set[String] = {
    val userGroups = getUnixGroups(username)
    logDebug("User: " + username + " Groups: " + userGroups.mkString(","))
    userGroups
  }

  // shells out a "bash -c id -Gn username" to get user groups
  private def getUnixGroups(username: String): Set[String] = {
    val cmdSeq = Seq("bash", "-c", "id -Gn " + username)
    // we need to get rid of the trailing "\n" from the result of command execution
    Utils.executeAndGetOutput(cmdSeq).stripLineEnd.split(" ").toSet
  }
} 
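
`getUnixGroups` above shells out via `Utils.executeAndGetOutput`, which returns the command's stdout and throws a `SparkException` on a non-zero exit code. A minimal sketch; `ShellDemo` is an illustrative name.

package org.apache.spark.security

import org.apache.spark.util.Utils

// Sketch (not part of Spark) of running a shell command and capturing its stdout.
object ShellDemo {
  def main(args: Array[String]): Unit = {
    val output = Utils.executeAndGetOutput(Seq("bash", "-c", "echo hello"))
    println(output.stripLineEnd) // "hello"
  }
}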
Example 49
Source File: NettyStreamManager.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rpc.netty

import java.io.File
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.server.StreamManager
import org.apache.spark.rpc.RpcEnvFileServer
import org.apache.spark.util.Utils


private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
  extends StreamManager with RpcEnvFileServer {

  private val files = new ConcurrentHashMap[String, File]()
  private val jars = new ConcurrentHashMap[String, File]()
  private val dirs = new ConcurrentHashMap[String, File]()

  override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
    throw new UnsupportedOperationException()
  }

  override def openStream(streamId: String): ManagedBuffer = {
    val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
    val file = ftype match {
      case "files" => files.get(fname)
      case "jars" => jars.get(fname)
      case other =>
        val dir = dirs.get(ftype)
        require(dir != null, s"Invalid stream URI: $ftype not found.")
        new File(dir, fname)
    }

    if (file != null && file.isFile()) {
      new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
    } else {
      null
    }
  }

  override def addFile(file: File): String = {
    val existingPath = files.putIfAbsent(file.getName, file)
    require(existingPath == null || existingPath == file,
      s"File ${file.getName} was already registered with a different path " +
        s"(old path = $existingPath, new path = $file)")
    s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}"
  }

  override def addJar(file: File): String = {
    val existingPath = jars.putIfAbsent(file.getName, file)
    require(existingPath == null || existingPath == file,
      s"File ${file.getName} was already registered with a different path " +
        s"(old path = $existingPath, new path = $file)")
    s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}"
  }

  override def addDirectory(baseUri: String, path: File): String = {
    val fixedBaseUri = validateDirectoryUri(baseUri)
    require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null,
      s"URI '$fixedBaseUri' already registered.")
    s"${rpcEnv.address.toSparkURL}$fixedBaseUri"
  }

} 
Example 50
Source File: RpcTimeout.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rpc

import java.util.concurrent.TimeoutException

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.control.NonFatal

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.util.Utils


  def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
    require(timeoutPropList.nonEmpty)

    // Find the first set property or use the default value with the first property
    val itr = timeoutPropList.iterator
    var foundProp: Option[(String, String)] = None
    while (itr.hasNext && foundProp.isEmpty) {
      val propKey = itr.next()
      conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
    }
    val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
    val timeout = { Utils.timeStringAsSeconds(finalProp._2).seconds }
    new RpcTimeout(timeout, finalProp._1)
  }
} 
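
The factory above parses the chosen property with `Utils.timeStringAsSeconds`. A small sketch of the accepted formats; `TimeoutParseDemo` is illustrative only, and the expected outputs in the comments are assumptions.

package org.apache.spark.rpc

import scala.concurrent.duration._

import org.apache.spark.util.Utils

// Sketch (not in Spark): Utils.timeStringAsSeconds understands suffixes such as "ms",
// "s", "min" and "h"; a bare number is read as seconds.
object TimeoutParseDemo {
  def main(args: Array[String]): Unit = {
    println(Utils.timeStringAsSeconds("120s"))       // 120
    println(Utils.timeStringAsSeconds("2min"))       // 120
    println(Utils.timeStringAsSeconds("30").seconds) // 30 seconds as a FiniteDuration
  }
}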
Example 51
Source File: SortShuffleWriter.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle.sort

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter

private[spark] class SortShuffleWriter[K, V, C](
    shuffleBlockResolver: IndexShuffleBlockResolver,
    handle: BaseShuffleHandle[K, V, C],
    mapId: Int,
    context: TaskContext)
  extends ShuffleWriter[K, V] with Logging {

  private val dep = handle.dependency

  private val blockManager = SparkEnv.get.blockManager

  private var sorter: ExternalSorter[K, V, _] = null

  // Are we in the process of stopping? Because map tasks can call stop() with success = true
  // and then call stop() with success = false if they get an exception, we want to make sure
  // we don't try deleting files, etc. twice.
  private var stopping = false

  private var mapStatus: MapStatus = null

  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics

  
  override def stop(success: Boolean): Option[MapStatus] = {
    try {
      if (stopping) {
        return None
      }
      stopping = true
      if (success) {
        return Option(mapStatus)
      } else {
        return None
      }
    } finally {
      // Clean up our sorter, which may have its own intermediate files
      if (sorter != null) {
        val startTime = System.nanoTime()
        sorter.stop()
        writeMetrics.incWriteTime(System.nanoTime - startTime)
        sorter = null
      }
    }
  }
}

private[spark] object SortShuffleWriter {
  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
    // We cannot bypass sorting if we need to do map-side aggregation.
    if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      false
    } else {
      val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
      dep.partitioner.numPartitions <= bypassMergeThreshold
    }
  }
} 
Example 52
Source File: MetricsConfig.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.metrics

import java.io.{FileInputStream, InputStream}
import java.util.Properties

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.matching.Regex

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

private[spark] class MetricsConfig(conf: SparkConf) extends Logging {

  private val DEFAULT_PREFIX = "*"
  private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r
  private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties"

  private[metrics] val properties = new Properties()
  private[metrics] var perInstanceSubProperties: mutable.HashMap[String, Properties] = null

  private def setDefaultProperties(prop: Properties) {
    prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet")
    prop.setProperty("*.sink.servlet.path", "/metrics/json")
    prop.setProperty("master.sink.servlet.path", "/metrics/master/json")
    prop.setProperty("applications.sink.servlet.path", "/metrics/applications/json")
  }

  
  private[this] def loadPropertiesFromFile(path: Option[String]): Unit = {
    var is: InputStream = null
    try {
      is = path match {
        case Some(f) => new FileInputStream(f)
        case None => Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)
      }

      if (is != null) {
        properties.load(is)
      }
    } catch {
      case e: Exception =>
        val file = path.getOrElse(DEFAULT_METRICS_CONF_FILENAME)
        logError(s"Error loading configuration file $file", e)
    } finally {
      if (is != null) {
        is.close()
      }
    }
  }

} 
Example 53
Source File: PythonGatewayServer.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  initializeLogIfNecessary(true)

  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
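
`main` above is wrapped in `Utils.tryOrExit`, which hands uncaught exceptions to Spark's uncaught-exception handler so the JVM terminates instead of failing silently. A minimal, hypothetical sketch of the same wrapper; `TryOrExitDemo` is illustrative only.

package org.apache.spark.api.python

import org.apache.spark.util.Utils

// Sketch (not in Spark) of guarding an entry point with Utils.tryOrExit.
object TryOrExitDemo {
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    println("driver-side setup that must not fail silently")
  }
}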
Example 54
Source File: PythonPartitioner.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import org.apache.spark.Partitioner
import org.apache.spark.util.Utils



private[spark] class PythonPartitioner(
  override val numPartitions: Int,
  val pyPartitionFunctionId: Long)
  extends Partitioner {

  override def getPartition(key: Any): Int = key match {
    case null => 0
    // We don't trust the Python partition function to return valid partition IDs, so
    // we apply a modulo numPartitions in any case.
    case key: Long => Utils.nonNegativeMod(key.toInt, numPartitions)
    case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: PythonPartitioner =>
      h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
    case _ =>
      false
  }

  override def hashCode: Int = 31 * numPartitions + pyPartitionFunctionId.hashCode
} 
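
`getPartition` above uses `Utils.nonNegativeMod` because the plain `%` operator can yield a negative result for negative hash codes. A small sketch; `NonNegativeModDemo` is illustrative only.

package org.apache.spark.api.python

import org.apache.spark.util.Utils

// Sketch (not in Spark) of why the partitioner above avoids a plain modulo.
object NonNegativeModDemo {
  def main(args: Array[String]): Unit = {
    println(-7 % 3)                      // -1, not a valid partition id
    println(Utils.nonNegativeMod(-7, 3)) // 2, always within [0, numPartitions)
  }
}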
Example 55
Source File: ExternalShuffleService.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import java.util.concurrent.CountDownLatch

import scala.collection.JavaConverters._

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.sasl.SaslServerBootstrap
import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap}
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.network.util.TransportConf
import org.apache.spark.util.{ShutdownHookManager, Utils}


  private[spark] def main(
      args: Array[String],
      newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = {
    Utils.initDaemon(log)
    val sparkConf = new SparkConf
    Utils.loadDefaultSparkProperties(sparkConf)
    val securityManager = new SecurityManager(sparkConf)

    // we override this value since this service is started from the command line
    // and we assume the user really wants it to be running
    sparkConf.set("spark.shuffle.service.enabled", "true")
    server = newShuffleService(sparkConf, securityManager)
    server.start()

    logDebug("Adding shutdown hook") // force eager creation of logger
    ShutdownHookManager.addShutdownHook { () =>
      logInfo("Shutting down shuffle service.")
      server.stop()
      barrier.countDown()
    }

    // keep running until the process is terminated
    barrier.await()
  }
} 
Example 56
Source File: FileSystemPersistenceEngine.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import java.io._

import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
import org.apache.spark.util.Utils



private[master] class FileSystemPersistenceEngine(
    val dir: String,
    val serializer: Serializer)
  extends PersistenceEngine with Logging {

  new File(dir).mkdir()

  override def persist(name: String, obj: Object): Unit = {
    serializeIntoFile(new File(dir + File.separator + name), obj)
  }

  override def unpersist(name: String): Unit = {
    val f = new File(dir + File.separator + name)
    if (!f.delete()) {
      logWarning(s"Error deleting ${f.getPath()}")
    }
  }

  override def read[T: ClassTag](prefix: String): Seq[T] = {
    val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix))
    files.map(deserializeFromFile[T])
  }

  private def serializeIntoFile(file: File, value: AnyRef) {
    val created = file.createNewFile()
    if (!created) { throw new IllegalStateException("Could not create file: " + file) }
    val fileOut = new FileOutputStream(file)
    var out: SerializationStream = null
    Utils.tryWithSafeFinally {
      out = serializer.newInstance().serializeStream(fileOut)
      out.writeObject(value)
    } {
      fileOut.close()
      if (out != null) {
        out.close()
      }
    }
  }

  private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
    val fileIn = new FileInputStream(file)
    var in: DeserializationStream = null
    try {
      in = serializer.newInstance().deserializeStream(fileIn)
      in.readObject[T]()
    } finally {
      fileIn.close()
      if (in != null) {
        in.close()
      }
    }
  }

} 
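
`serializeIntoFile` above relies on `Utils.tryWithSafeFinally`, which keeps an exception thrown by the finally block from masking the original failure. A hypothetical helper using the same construct; `SafeFinallyDemo` and `writeBytes` are illustrative names.

package org.apache.spark.deploy.master

import java.io.{File, FileOutputStream}

import org.apache.spark.util.Utils

// Sketch (not in Spark): if both the body and the finally block throw, the close()
// failure is attached as a suppressed exception rather than replacing the original one.
object SafeFinallyDemo {
  def writeBytes(file: File, bytes: Array[Byte]): Unit = {
    val out = new FileOutputStream(file)
    Utils.tryWithSafeFinally {
      out.write(bytes)
    } {
      out.close()
    }
  }
}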
Example 57
Source File: DriverInfo.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import java.util.Date

import org.apache.spark.deploy.DriverDescription
import org.apache.spark.util.Utils

private[deploy] class DriverInfo(
    val startTime: Long,
    val id: String,
    val desc: DriverDescription,
    val submitDate: Date)
  extends Serializable {

  @transient var state: DriverState.Value = DriverState.SUBMITTED
  // If the driver fails at launch, the exception is stored here (reset in init()).
  @transient var exception: Option[Exception] = None
  // Most recent worker assigned to this driver.
  @transient var worker: Option[WorkerInfo] = None

  init()

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init(): Unit = {
    state = DriverState.SUBMITTED
    worker = None
    exception = None
  }
} 
Example 58
Source File: MasterArguments.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import scala.annotation.tailrec

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.{IntParam, Utils}


  private def printUsageAndExit(exitCode: Int) {
    // scalastyle:off println
    System.err.println(
      "Usage: Master [options]\n" +
      "\n" +
      "Options:\n" +
      "  -i HOST, --ip HOST     Hostname to listen on (deprecated, please use --host or -h) \n" +
      "  -h HOST, --host HOST   Hostname to listen on\n" +
      "  -p PORT, --port PORT   Port to listen on (default: 7077)\n" +
      "  --webui-port PORT      Port for web UI (default: 8080)\n" +
      "  --properties-file FILE Path to a custom Spark properties file.\n" +
      "                         Default is conf/spark-defaults.conf.")
    // scalastyle:on println
    System.exit(exitCode)
  }
} 
Example 59
Source File: WorkerInfo.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String,
    val host: String,
    val port: Int,
    val cores: Int,
    val memory: Int,
    val endpoint: RpcEndpointRef,
    val webUiAddress: String)
  extends Serializable {

  Utils.checkHost(host, "Expected hostname")
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init()

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
Example 60
Source File: ClientArguments.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import java.net.{URI, URISyntaxException}

import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

import org.apache.log4j.Level

import org.apache.spark.util.{IntParam, MemoryParam, Utils}


  private def printUsageAndExit(exitCode: Int) {
    // TODO: It wouldn't be too hard to allow users to submit their app and dependency jars
    //       separately similar to in the YARN client.
    val usage =
     s"""
      |Usage: DriverClient [options] launch <active-master> <jar-url> <main-class> [driver options]
      |Usage: DriverClient kill <active-master> <driver-id>
      |
      |Options:
      |   -c CORES, --cores CORES        Number of cores to request (default: $DEFAULT_CORES)
      |   -m MEMORY, --memory MEMORY     Megabytes of memory to request (default: $DEFAULT_MEMORY)
      |   -s, --supervise                Whether to restart the driver on failure
      |                                  (default: $DEFAULT_SUPERVISE)
      |   -v, --verbose                  Print more debugging output
     """.stripMargin
    // scalastyle:off println
    System.err.println(usage)
    // scalastyle:on println
    System.exit(exitCode)
  }
}

private[deploy] object ClientArguments {
  val DEFAULT_CORES = 1
  val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB
  val DEFAULT_SUPERVISE = false

  def isValidJarUrl(s: String): Boolean = {
    try {
      val uri = new URI(s)
      uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar")
    } catch {
      case _: URISyntaxException => false
    }
  }
} 
Example 61
Source File: DriverWrapper.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import java.io.File

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}

object DriverWrapper {
  def main(args: Array[String]) {
    args.toList match {
      case workerUrl :: userJar :: mainClass :: extraArgs =>
        val conf = new SparkConf()
        val rpcEnv = RpcEnv.create("Driver",
          Utils.localHostName(), 0, conf, new SecurityManager(conf))
        rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl))

        val currentLoader = Thread.currentThread.getContextClassLoader
        val userJarUrl = new File(userJar).toURI().toURL()
        val loader =
          if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
            new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader)
          } else {
            new MutableURLClassLoader(Array(userJarUrl), currentLoader)
          }
        Thread.currentThread.setContextClassLoader(loader)

        // Delegate to supplied main class
        val clazz = Utils.classForName(mainClass)
        val mainMethod = clazz.getMethod("main", classOf[Array[String]])
        mainMethod.invoke(null, extraArgs.toArray[String])

        rpcEnv.shutdown()

      case _ =>
        // scalastyle:off println
        System.err.println("Usage: DriverWrapper <workerUrl> <userJar> <driverMainClass> [options]")
        // scalastyle:on println
        System.exit(-1)
    }
  }
} 
Example 62
Source File: HistoryServerArguments.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.history

import scala.annotation.tailrec

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
  extends Logging {
  private var propertiesFile: String = null

  parse(args.toList)

  @tailrec
  private def parse(args: List[String]): Unit = {
    if (args.length == 1) {
      setLogDirectory(args.head)
    } else {
      args match {
        case ("--dir" | "-d") :: value :: tail =>
          setLogDirectory(value)
          parse(tail)

        case ("--help" | "-h") :: tail =>
          printUsageAndExit(0)

        case ("--properties-file") :: value :: tail =>
          propertiesFile = value
          parse(tail)

        case Nil =>

        case _ =>
          printUsageAndExit(1)
      }
    }
  }

  private def setLogDirectory(value: String): Unit = {
    logWarning("Setting log directory through the command line is deprecated as of " +
      "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.")
    conf.set("spark.history.fs.logDirectory", value)
  }

  // This mutates the SparkConf, so all accesses to it must be made after this line
  Utils.loadDefaultSparkProperties(conf, propertiesFile)

  private def printUsageAndExit(exitCode: Int) {
    // scalastyle:off println
    System.err.println(
      """
      |Usage: HistoryServer [options]
      |
      |Options:
      |  DIR                         Deprecated; set spark.history.fs.logDirectory directly
      |  --dir DIR (-d DIR)          Deprecated; set spark.history.fs.logDirectory directly
      |  --properties-file FILE      Path to a custom Spark properties file.
      |                              Default is conf/spark-defaults.conf.
      |
      |Configuration options can be set by setting the corresponding JVM system property.
      |History Server options are always available; additional options depend on the provider.
      |
      |History Server options:
      |
      |  spark.history.ui.port              Port where server will listen for connections
      |                                     (default 18080)
      |  spark.history.acls.enable          Whether to enable view acls for all applications
      |                                     (default false)
      |  spark.history.provider             Name of history provider class (defaults to
      |                                     file system-based provider)
      |  spark.history.retainedApplications Max number of application UIs to keep loaded in memory
      |                                     (default 50)
      |FsHistoryProvider options:
      |
      |  spark.history.fs.logDirectory      Directory where app logs are stored
      |                                     (default: file:/tmp/spark-events)
      |  spark.history.fs.updateInterval    How often to reload log data from storage
      |                                     (in seconds, default: 10)
      |""".stripMargin)
    // scalastyle:on println
    System.exit(exitCode)
  }

} 
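
The parser above calls `Utils.loadDefaultSparkProperties`, which mutates the given `SparkConf` and returns the properties file it read; with a null path it falls back to the default `spark-defaults.conf` under the Spark conf directory. A minimal sketch; `LoadDefaultsDemo` is illustrative only.

package org.apache.spark.deploy.history

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

// Sketch (not in Spark) of loading spark-defaults.conf into a SparkConf, as above.
object LoadDefaultsDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val loadedFile = Utils.loadDefaultSparkProperties(conf, null)
    println(s"Loaded properties from: $loadedFile")
  }
}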
Example 63
Source File: LocalSparkCluster.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.Utils


    for (workerNum <- 1 to numWorkers) {
      val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker,
        memoryPerWorker, masters, null, Some(workerNum), _conf)
      workerRpcEnvs += workerEnv
    }

    masters
  }

  def stop() {
    logInfo("Shutting down local Spark cluster.")
    // Stop the workers before the master so they don't get upset that it disconnected
    workerRpcEnvs.foreach(_.shutdown())
    masterRpcEnvs.foreach(_.shutdown())
    workerRpcEnvs.foreach(_.awaitTermination())
    masterRpcEnvs.foreach(_.awaitTermination())
    masterRpcEnvs.clear()
    workerRpcEnvs.clear()
  }
} 
Example 64
Source File: JavaSerializer.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import java.io._
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}

private[spark] class JavaSerializationStream(
    out: OutputStream, counterReset: Int, extraDebugInfo: Boolean)
  extends SerializationStream {
  private val objOut = new ObjectOutputStream(out)
  private var counter = 0

  
@DeveloperApi
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
  private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
  private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)

  protected def this() = this(new SparkConf())  // For deserialization only

  override def newInstance(): SerializerInstance = {
    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
    new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)
  }

  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
    out.writeInt(counterReset)
    out.writeBoolean(extraDebugInfo)
  }

  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    counterReset = in.readInt()
    extraDebugInfo = in.readBoolean()
  }
} 
Example 65
Source File: BlacklistTracker.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.util.Utils

private[scheduler] object BlacklistTracker extends Logging {

  private val DEFAULT_TIMEOUT = "1h"

  // Resolve the blacklist timeout in milliseconds: prefer the current conf entry, then the
  // legacy one, and finally fall back to DEFAULT_TIMEOUT (used by validateBlacklistConfs below).
  def getBlacklistTimeout(conf: SparkConf): Long = {
    conf.get(config.BLACKLIST_TIMEOUT_CONF).getOrElse {
      conf.get(config.BLACKLIST_LEGACY_TIMEOUT_CONF).getOrElse {
        Utils.timeStringAsMs(DEFAULT_TIMEOUT)
      }
    }
  }
  def validateBlacklistConfs(conf: SparkConf): Unit = {

    def mustBePos(k: String, v: String): Unit = {
      throw new IllegalArgumentException(s"$k was $v, but must be > 0.")
    }

    Seq(
      config.MAX_TASK_ATTEMPTS_PER_EXECUTOR,
      config.MAX_TASK_ATTEMPTS_PER_NODE,
      config.MAX_FAILURES_PER_EXEC_STAGE,
      config.MAX_FAILED_EXEC_PER_NODE_STAGE
    ).foreach { config =>
      val v = conf.get(config)
      if (v <= 0) {
        mustBePos(config.key, v.toString)
      }
    }

    val timeout = getBlacklistTimeout(conf)
    if (timeout <= 0) {
      // first, figure out where the timeout came from, to include the right conf in the message.
      conf.get(config.BLACKLIST_TIMEOUT_CONF) match {
        case Some(t) =>
          mustBePos(config.BLACKLIST_TIMEOUT_CONF.key, timeout.toString)
        case None =>
          mustBePos(config.BLACKLIST_LEGACY_TIMEOUT_CONF.key, timeout.toString)
      }
    }

    val maxTaskFailures = conf.get(config.MAX_TASK_FAILURES)
    val maxNodeAttempts = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE)

    if (maxNodeAttempts >= maxTaskFailures) {
      throw new IllegalArgumentException(s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key} " +
        s"( = ${maxNodeAttempts}) was >= ${config.MAX_TASK_FAILURES.key} " +
        s"( = ${maxTaskFailures} ).  Though blacklisting is enabled, with this configuration, " +
        s"Spark will not be robust to one bad node.  Decrease " +
        s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " +
        s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}")
    }
  }
} 
Example 66
Source File: TaskResult.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.io._
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.BlockId
import org.apache.spark.util.{AccumulatorV2, Utils}

// Task result. Also contains updates to accumulator variables.
private[spark] sealed trait TaskResult[T]


  def value(resultSer: SerializerInstance = null): T = {
    if (valueObjectDeserialized) {
      valueObject
    } else {
      // This should not run when holding a lock because it may cost dozens of seconds for a large
      // value
      val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer
      valueObject = ser.deserialize(valueBytes)
      valueObjectDeserialized = true
      valueObject
    }
  }
} 
Example 67
Source File: RDDInfo.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.util.Utils

@DeveloperApi
class RDDInfo(
    val id: Int,
    var name: String,
    val numPartitions: Int,
    var storageLevel: StorageLevel,
    val parentIds: Seq[Int],
    val callSite: String = "",
    val scope: Option[RDDOperationScope] = None)
  extends Ordered[RDDInfo] {

  var numCachedPartitions = 0
  var memSize = 0L
  var diskSize = 0L
  var externalBlockStoreSize = 0L

  def isCached: Boolean = (memSize + diskSize > 0) && numCachedPartitions > 0

  override def toString: String = {
    import Utils.bytesToString
    ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " +
      "MemorySize: %s; DiskSize: %s").format(
        name, id, storageLevel.toString, numCachedPartitions, numPartitions,
        bytesToString(memSize), bytesToString(diskSize))
  }

  override def compare(that: RDDInfo): Int = {
    this.id - that.id
  }
}

private[spark] object RDDInfo {
  def fromRdd(rdd: RDD[_]): RDDInfo = {
    val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd))
    val parentIds = rdd.dependencies.map(_.rdd.id)
    new RDDInfo(rdd.id, rddName, rdd.partitions.length,
      rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope)
  }
} 
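
RDDInfo.toString above relies on Utils.bytesToString for human-readable sizes. A short usage sketch (names are made up; the exact output format is an assumption about this Spark lineage):

package org.apache.spark.example

import org.apache.spark.util.Utils

object BytesToStringSketch {
  def main(args: Array[String]): Unit = {
    // Formats raw byte counts, e.g. something like "2.0 KB" and "3.0 MB".
    println(Utils.bytesToString(2L * 1024))
    println(Utils.bytesToString(3L * 1024 * 1024))
  }
}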
Example 68
Source File: BlockManagerId.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


// The BlockManagerId class itself is elided in this excerpt; below is its companion object,
// which caches and canonicalizes instances.
private[spark] object BlockManagerId {

  def apply(
      execId: String,
      host: String,
      port: Int,
      topologyInfo: Option[String] = None): BlockManagerId =
    getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo))

  def apply(in: ObjectInput): BlockManagerId = {
    val obj = new BlockManagerId()
    obj.readExternal(in)
    getCachedBlockManagerId(obj)
  }

  val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()

  def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
    blockManagerIdCache.putIfAbsent(id, id)
    blockManagerIdCache.get(id)
  }
} 
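
getCachedBlockManagerId above is an interning cache: putIfAbsent keeps the first instance ever seen, and get returns that canonical copy, so deserialized duplicates collapse into one shared object. A generic sketch of the same idea, with a hypothetical Key type:

import java.util.concurrent.ConcurrentHashMap

object InternCacheSketch {
  case class Key(host: String, port: Int)

  private val cache = new ConcurrentHashMap[Key, Key]()

  def intern(k: Key): Key = {
    cache.putIfAbsent(k, k)
    cache.get(k)
  }

  def main(args: Array[String]): Unit = {
    val a = intern(Key("host-1", 7077))
    val b = intern(Key("host-1", 7077))
    println(a eq b)  // true: both callers now share the first cached instance
  }
}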
Example 69
Source File: TopologyMapper.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


@DeveloperApi
class FileBasedTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
  val topologyFile = conf.getOption("spark.storage.replication.topologyFile")
  require(topologyFile.isDefined, "Please specify topology file via " +
    "spark.storage.replication.topologyFile for FileBasedTopologyMapper.")
  val topologyMap = Utils.getPropertiesFromFile(topologyFile.get)

  override def getTopologyForHost(hostname: String): Option[String] = {
    val topology = topologyMap.get(hostname)
    if (topology.isDefined) {
      logDebug(s"$hostname -> ${topology.get}")
    } else {
      logWarning(s"$hostname does not have any topology information")
    }
    topology
  }
} 
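
FileBasedTopologyMapper above reads a plain java.util.Properties file through Utils.getPropertiesFromFile, which returns a Map[String, String]. A sketch of producing and loading such a file (file and object names are made up; the package sits under org.apache.spark because Utils is private[spark]):

package org.apache.spark.example

import java.io.{File, PrintWriter}

import org.apache.spark.util.Utils

object TopologyFileSketch {
  def main(args: Array[String]): Unit = {
    val file = new File(Utils.createTempDir(), "topology.properties")
    val out = new PrintWriter(file)
    try {
      out.println("host-1=/rack-a")
      out.println("host-2=/rack-b")
    } finally {
      out.close()
    }
    val topologyMap = Utils.getPropertiesFromFile(file.getAbsolutePath)
    println(topologyMap.get("host-1"))  // Some(/rack-a)
  }
}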
Example 70
Source File: BlockManagerSlaveEndpoint.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}


private[storage]
class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    future.onSuccess { case response =>
      logDebug("Done " + actionMessage + ", response is " + response)
      context.reply(response)
      logDebug("Sent response: " + response + " to " + context.senderAddress)
    }
    future.onFailure { case t: Throwable =>
      logError("Error in " + actionMessage, t)
      context.sendFailure(t)
    }
  }

  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
} 
Example 71
Source File: DiskStore.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{FileOutputStream, IOException, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

import com.google.common.io.Closeables

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBuffer


private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging {

  // Files below this size are read directly in getBytes instead of being memory-mapped.
  private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m")

  def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = {
    if (contains(blockId)) {
      throw new IllegalStateException(s"Block $blockId is already present in the disk store")
    }
    logDebug(s"Attempting to put block $blockId")
    val startTime = System.currentTimeMillis
    val file = diskManager.getFile(blockId)
    val fileOutputStream = new FileOutputStream(file)
    var threwException: Boolean = true
    try {
      writeFunc(fileOutputStream)
      threwException = false
    } finally {
      try {
        Closeables.close(fileOutputStream, threwException)
      } finally {
        if (threwException) {
          remove(blockId)
        }
      }
    }
    val finishTime = System.currentTimeMillis
    logDebug("Block %s stored as %s file on disk in %d ms".format(
      file.getName,
      Utils.bytesToString(file.length()),
      finishTime - startTime))
  }

  def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
    put(blockId) { fileOutputStream =>
      val channel = fileOutputStream.getChannel
      Utils.tryWithSafeFinally {
        bytes.writeFully(channel)
      } {
        channel.close()
      }
    }
  }

  def getBytes(blockId: BlockId): ChunkedByteBuffer = {
    val file = diskManager.getFile(blockId.name)
    val channel = new RandomAccessFile(file, "r").getChannel
    Utils.tryWithSafeFinally {
      // For small files, directly read rather than memory map
      if (file.length < minMemoryMapBytes) {
        val buf = ByteBuffer.allocate(file.length.toInt)
        channel.position(0)
        while (buf.remaining() != 0) {
          if (channel.read(buf) == -1) {
            throw new IOException("Reached EOF before filling buffer\n" +
              s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
          }
        }
        buf.flip()
        new ChunkedByteBuffer(buf)
      } else {
        new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length))
      }
    } {
      channel.close()
    }
  }

  def remove(blockId: BlockId): Boolean = {
    val file = diskManager.getFile(blockId.name)
    if (file.exists()) {
      val ret = file.delete()
      if (!ret) {
        logWarning(s"Error deleting ${file.getPath()}")
      }
      ret
    } else {
      false
    }
  }

  def contains(blockId: BlockId): Boolean = {
    val file = diskManager.getFile(blockId.name)
    file.exists()
  }
} 
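
Both putBytes and getBytes above lean on Utils.tryWithSafeFinally, which runs the cleanup block even if the body throws and keeps the body's exception as the primary failure if the cleanup also fails. A minimal sketch under the same private[spark] packaging assumption, with made-up names:

package org.apache.spark.example

import java.io.{File, FileOutputStream}
import java.nio.ByteBuffer

import org.apache.spark.util.Utils

object SafeFinallySketch {
  def main(args: Array[String]): Unit = {
    val target = new File(Utils.createTempDir(), "greeting.bin")
    val out = new FileOutputStream(target)
    val channel = out.getChannel
    Utils.tryWithSafeFinally {
      // Write a few bytes through the channel; any exception here is the one reported.
      channel.write(ByteBuffer.wrap("hello".getBytes("UTF-8")))
    } {
      channel.close()
      out.close()
    }
    println(target.length())  // 5
  }
}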
Example 72
Source File: ZippedWithIndexRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

private[spark]
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}


private[spark]
class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) {

  // The starting index of each partition, computed by counting the elements of the
  // preceding partitions.
  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
      Array.empty
    } else if (n == 1) {
      Array(0L)
    } else {
      prev.context.runJob(
        prev,
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)
    }
  }

  override def getPartitions: Array[Partition] = {
    firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
  }

  override def getPreferredLocations(split: Partition): Seq[String] =
    firstParent[T].preferredLocations(split.asInstanceOf[ZippedWithIndexRDDPartition].prev)

  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    val parentIter = firstParent[T].iterator(split.prev, context)
    Utils.getIteratorZipWithIndex(parentIter, split.startIndex)
  }
} 
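
The two Utils helpers used above can be exercised locally: getIteratorSize counts the elements of a partition iterator, and getIteratorZipWithIndex pairs each element with an index starting at a given offset. A quick sketch (made-up names, org.apache.spark packaging because Utils is private[spark]):

package org.apache.spark.example

import org.apache.spark.util.Utils

object ZipWithIndexSketch {
  def main(args: Array[String]): Unit = {
    println(Utils.getIteratorSize(Iterator("a", "b", "c")))                  // 3
    println(Utils.getIteratorZipWithIndex(Iterator("x", "y"), 10L).toList)   // List((x,10), (y,11))
  }
}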
Example 73
Source File: CartesianRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.util.Utils

private[spark]
class CartesianPartition(
    idx: Int,
    @transient private val rdd1: RDD[_],
    @transient private val rdd2: RDD[_],
    s1Index: Int,
    s2Index: Int
  ) extends Partition {
  var s1 = rdd1.partitions(s1Index)
  var s2 = rdd2.partitions(s2Index)
  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    s1 = rdd1.partitions(s1Index)
    s2 = rdd2.partitions(s2Index)
    oos.defaultWriteObject()
  }
}

private[spark]
class CartesianRDD[T: ClassTag, U: ClassTag](
    sc: SparkContext,
    var rdd1 : RDD[T],
    var rdd2 : RDD[U])
  extends RDD[(T, U)](sc, Nil)
  with Serializable {

  val numPartitionsInRdd2 = rdd2.partitions.length

  override def getPartitions: Array[Partition] = {
    // create the cross product split
    val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length)
    for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
      val idx = s1.index * numPartitionsInRdd2 + s2.index
      array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
    }
    array
  }

  override def getPreferredLocations(split: Partition): Seq[String] = {
    val currSplit = split.asInstanceOf[CartesianPartition]
    (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
  }

  override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
    val currSplit = split.asInstanceOf[CartesianPartition]
    for (x <- rdd1.iterator(currSplit.s1, context);
         y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
  }

  override def getDependencies: Seq[Dependency[_]] = List(
    new NarrowDependency(rdd1) {
      def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2)
    },
    new NarrowDependency(rdd2) {
      def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2)
    }
  )

  override def clearDependencies() {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
  }
} 
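
CartesianRDD is what the public RDD.cartesian method returns. The sketch below (local master, toy data, made-up names) only shows that entry point rather than constructing the class directly:

import org.apache.spark.{SparkConf, SparkContext}

object CartesianSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("CartesianSketch"))
    try {
      val letters = sc.parallelize(Seq("a", "b"))
      val numbers = sc.parallelize(Seq(1, 2))
      // All (letter, number) pairs: 2 x 2 = 4 elements.
      println(letters.cartesian(numbers).collect().toSeq)
    } finally {
      sc.stop()
    }
  }
}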
Example 74
Source File: UnionRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport}
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    oos.defaultWriteObject()
  }
}

object UnionRDD {
  private[spark] lazy val partitionEvalTaskSupport =
    new ForkJoinTaskSupport(new ForkJoinPool(8))
}

@DeveloperApi
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // visible for testing
  private[spark] val isPartitionListingParallel: Boolean =
    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

  override def getPartitions: Array[Partition] = {
    val parRDDs = if (isPartitionListingParallel) {
      val parArray = rdds.par
      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
      parArray
    } else {
      rdds
    }
    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1
    }
    array
  }

  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length
    }
    deps
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)
  }

  override def getPreferredLocations(s: Partition): Seq[String] =
    s.asInstanceOf[UnionPartition[T]].preferredLocations()

  override def clearDependencies() {
    super.clearDependencies()
    rdds = null
  }
} 
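
UnionRDD backs SparkContext.union (and RDD.union); the "spark.rdd.parallelListingThreshold" key read above controls when listing the parents' partitions switches to a parallel collection. A sketch using only the public API, with toy data and made-up names:

import org.apache.spark.{SparkConf, SparkContext}

object UnionSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("UnionSketch")
      // Partition listing goes parallel once the number of unioned RDDs exceeds this value.
      .set("spark.rdd.parallelListingThreshold", "10")
    val sc = new SparkContext(conf)
    try {
      val parts = (1 to 3).map(i => sc.parallelize(Seq(i, i * 10)))
      println(sc.union(parts).collect().toSeq)  // Seq(1, 10, 2, 20, 3, 30)
    } finally {
      sc.stop()
    }
  }
}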
Example 75
Source File: PartitionwiseSampledRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

private[spark]
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}


private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    preservesPartitioning: Boolean,
    @transient private val seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
  }

  override def getPreferredLocations(split: Partition): Seq[String] =
    firstParent[T].preferredLocations(split.asInstanceOf[PartitionwiseSampledRDDPartition].prev)

  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.setSeed(split.seed)
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
  }
} 
Example 76
Source File: PartitionerAwareUnionRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils


private[spark]
class PartitionerAwareUnionRDDPartition(
    @transient val rdds: Seq[RDD[_]],
    override val index: Int
  ) extends Partition {
  var parents = rdds.map(_.partitions(index)).toArray

  override def hashCode(): Int = index

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the references to the parent partitions at the time of task serialization
    parents = rdds.map(_.partitions(index)).toArray
    oos.defaultWriteObject()
  }
}

private[spark]
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
  require(rdds.nonEmpty)
  require(rdds.forall(_.partitioner.isDefined))
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  override val partitioner = rdds.head.partitioner

  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map { index =>
      new PartitionerAwareUnionRDDPartition(rdds, index)
    }.toArray
  }

  // Get the location where most of the partitions of parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    val locations = rdds.zip(parentPartitions).flatMap {
      case (rdd, part) =>
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
        parentLocations
    }
    val location = if (locations.isEmpty) {
      None
    } else {
      // Find the location that maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    }
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)
    location.toSeq
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    rdds.zip(parentPartitions).iterator.flatMap {
      case (rdd, p) => rdd.iterator(p, context)
    }
  }

  override def clearDependencies() {
    super.clearDependencies()
    rdds = null
  }

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host)
  }
} 
Example 77
Source File: PythonBroadcastSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.{File, PrintWriter}

import scala.io.Source

import org.scalatest.Matchers

import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.Utils

// This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize
// a PythonBroadcast:
class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext {
  test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") {
    val tempDir = Utils.createTempDir()
    val broadcastedString = "Hello, world!"
    def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = {
      val source = Source.fromFile(broadcast.path)
      val contents = source.mkString
      source.close()
      contents should be (broadcastedString)
    }
    try {
      val broadcastDataFile: File = {
        val file = new File(tempDir, "broadcastData")
        val printWriter = new PrintWriter(file)
        printWriter.write(broadcastedString)
        printWriter.close()
        file
      }
      val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath)
      assertBroadcastIsValid(broadcast)
      val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
      val deserializedBroadcast =
        Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance())
      assertBroadcastIsValid(deserializedBroadcast)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
} 
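
The try/finally shape of the test above is the standard lifecycle for scratch directories: Utils.createTempDir returns a fresh directory (also registered for deletion on JVM exit) and Utils.deleteRecursively removes it eagerly. A compact sketch with made-up file names, packaged under org.apache.spark because Utils is private[spark]:

package org.apache.spark.example

import java.io.{File, PrintWriter}

import org.apache.spark.util.Utils

object TempDirSketch {
  def main(args: Array[String]): Unit = {
    val tempDir = Utils.createTempDir()
    try {
      val scratch = new File(tempDir, "scratch.txt")
      val writer = new PrintWriter(scratch)
      try writer.write("scratch data") finally writer.close()
      println(scratch.length())  // 12: the bytes written above
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}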
Example 78
Source File: PythonRunnerSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils

class PythonRunnerSuite extends SparkFunSuite {

  // Test formatting a single path to be added to the PYTHONPATH
  test("format path") {
    assert(PythonRunner.formatPath("spark.py") === "spark.py")
    assert(PythonRunner.formatPath("file:/spark.py") === "/spark.py")
    assert(PythonRunner.formatPath("file:///spark.py") === "/spark.py")
    assert(PythonRunner.formatPath("local:/spark.py") === "/spark.py")
    assert(PythonRunner.formatPath("local:///spark.py") === "/spark.py")
    if (Utils.isWindows) {
      assert(PythonRunner.formatPath("file:/C:/a/b/spark.py", testWindows = true) ===
        "C:/a/b/spark.py")
      assert(PythonRunner.formatPath("C:\\a\\b\\spark.py", testWindows = true) ===
        "C:/a/b/spark.py")
      assert(PythonRunner.formatPath("C:\\a b\\spark.py", testWindows = true) ===
        "C:/a b/spark.py")
    }
    intercept[IllegalArgumentException] { PythonRunner.formatPath("one:two") }
    intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:s3:xtremeFS") }
    intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:/path/to/some.py") }
  }

  // Test formatting multiple comma-separated paths to be added to the PYTHONPATH
  test("format paths") {
    assert(PythonRunner.formatPaths("spark.py") === Array("spark.py"))
    assert(PythonRunner.formatPaths("file:/spark.py") === Array("/spark.py"))
    assert(PythonRunner.formatPaths("file:/app.py,local:/spark.py") ===
      Array("/app.py", "/spark.py"))
    assert(PythonRunner.formatPaths("me.py,file:/you.py,local:/we.py") ===
      Array("me.py", "/you.py", "/we.py"))
    if (Utils.isWindows) {
      assert(PythonRunner.formatPaths("C:\\a\\b\\spark.py", testWindows = true) ===
        Array("C:/a/b/spark.py"))
      assert(PythonRunner.formatPaths("C:\\free.py,pie.py", testWindows = true) ===
        Array("C:/free.py", "pie.py"))
      assert(PythonRunner.formatPaths("lovely.py,C:\\free.py,file:/d:/fry.py",
        testWindows = true) ===
        Array("lovely.py", "C:/free.py", "d:/fry.py"))
    }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("one:two,three") }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("two,three,four:five:six") }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("foo.py,hdfs:/some.py") }
  }
} 
Example 79
Source File: CommandUtilsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import org.scalatest.{Matchers, PrivateMethodTester}

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.Command
import org.apache.spark.util.Utils

class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTester {

  test("set libraryPath correctly") {
    val appId = "12345-worker321-9876"
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val cmd = new Command("mainClass", Seq(), Map(), Seq(), Seq("libraryPathToB"), Seq())
    val builder = CommandUtils.buildProcessBuilder(
      cmd, new SecurityManager(new SparkConf), 512, sparkHome, t => t)
    val libraryPath = Utils.libraryPathEnvName
    val env = builder.environment
    env.keySet should contain(libraryPath)
    assert(env.get(libraryPath).startsWith("libraryPathToB"))
  }

  test("auth secret shouldn't appear in java opts") {
    val buildLocalCommand = PrivateMethod[Command]('buildLocalCommand)
    val conf = new SparkConf
    val secret = "This is the secret sauce"
    // set auth secret
    conf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, secret)
    val command = new Command("mainClass", Seq(), Map(), Seq(), Seq("lib"),
      Seq("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF + "=" + secret))

    // auth is not set
    var cmd = CommandUtils invokePrivate buildLocalCommand(
      command, new SecurityManager(conf), (t: String) => t, Seq(), Map())
    assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF)))
    assert(!cmd.environment.contains(SecurityManager.ENV_AUTH_SECRET))

    // auth is set to false
    conf.set(SecurityManager.SPARK_AUTH_CONF, "false")
    cmd = CommandUtils invokePrivate buildLocalCommand(
      command, new SecurityManager(conf), (t: String) => t, Seq(), Map())
    assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF)))
    assert(!cmd.environment.contains(SecurityManager.ENV_AUTH_SECRET))

    // auth is set to true
    conf.set(SecurityManager.SPARK_AUTH_CONF, "true")
    cmd = CommandUtils invokePrivate buildLocalCommand(
      command, new SecurityManager(conf), (t: String) => t, Seq(), Map())
    assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF)))
    assert(cmd.environment(SecurityManager.ENV_AUTH_SECRET) === secret)
  }
} 
Example 80
Source File: HistoryServerArgumentsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.history

import java.io.File
import java.nio.charset.StandardCharsets._

import com.google.common.io.Files

import org.apache.spark._
import org.apache.spark.util.Utils

class HistoryServerArgumentsSuite extends SparkFunSuite {

  private val logDir = new File("src/test/resources/spark-events")
  private val conf = new SparkConf()
    .set("spark.history.fs.logDirectory", logDir.getAbsolutePath)
    .set("spark.history.fs.updateInterval", "1")
    .set("spark.testing", "true")

  test("No Arguments Parsing") {
    val argStrings = Array.empty[String]
    val hsa = new HistoryServerArguments(conf, argStrings)
    assert(conf.get("spark.history.fs.logDirectory") === logDir.getAbsolutePath)
    assert(conf.get("spark.history.fs.updateInterval") === "1")
    assert(conf.get("spark.testing") === "true")
  }

  test("Directory Arguments Parsing --dir or -d") {
    val argStrings = Array("--dir", "src/test/resources/spark-events1")
    val hsa = new HistoryServerArguments(conf, argStrings)
    assert(conf.get("spark.history.fs.logDirectory") === "src/test/resources/spark-events1")
  }

  test("Directory Param can also be set directly") {
    val argStrings = Array("src/test/resources/spark-events2")
    val hsa = new HistoryServerArguments(conf, argStrings)
    assert(conf.get("spark.history.fs.logDirectory") === "src/test/resources/spark-events2")
  }

  test("Properties File Arguments Parsing --properties-file") {
    val tmpDir = Utils.createTempDir()
    val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir)
    try {
      Files.write("spark.test.CustomPropertyA blah\n" +
        "spark.test.CustomPropertyB notblah\n", outFile, UTF_8)
      val argStrings = Array("--properties-file", outFile.getAbsolutePath)
      val hsa = new HistoryServerArguments(conf, argStrings)
      assert(conf.get("spark.test.CustomPropertyA") === "blah")
      assert(conf.get("spark.test.CustomPropertyB") === "notblah")
    } finally {
      Utils.deleteRecursively(tmpDir)
    }
  }

} 
Example 81
Source File: SortShuffleSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.io.File

import scala.collection.JavaConverters._

import org.apache.commons.io.FileUtils
import org.apache.commons.io.filefilter.TrueFileFilter
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.util.Utils

class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {

  // This test suite should run all tests in ShuffleSuite with sort-based shuffle.

  private var tempDir: File = _

  override def beforeAll() {
    super.beforeAll()
    conf.set("spark.shuffle.manager", "sort")
  }

  override def beforeEach(): Unit = {
    super.beforeEach()
    tempDir = Utils.createTempDir()
    conf.set("spark.local.dir", tempDir.getAbsolutePath)
  }

  override def afterEach(): Unit = {
    try {
      Utils.deleteRecursively(tempDir)
    } finally {
      super.afterEach()
    }
  }

  test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
    sc = new SparkContext("local", "test", conf)
    // Create a shuffled RDD and verify that it actually uses the new serialized map output path
    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
      .setSerializer(new KryoSerializer(conf))
    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
    assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
    ensureFilesAreCleanedUp(shuffledRdd)
  }

  test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
    sc = new SparkContext("local", "test", conf)
    // Create a shuffled RDD and verify that it actually uses the old deserialized map output path
    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
      .setSerializer(new JavaSerializer(conf))
    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
    assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
    ensureFilesAreCleanedUp(shuffledRdd)
  }

  private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
    def getAllFiles: Set[File] =
      FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
    val filesBeforeShuffle = getAllFiles
    // Force the shuffle to be performed
    shuffledRdd.count()
    // Ensure that the shuffle actually created files that will need to be cleaned up
    val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
    // The single map task should have produced exactly one data file and one index file.
    filesCreatedByShuffle.map(_.getName) should be(
      Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
    // Check that the cleanup actually removes the files
    sc.env.blockManager.master.removeShuffle(0, blocking = true)
    for (file <- filesCreatedByShuffle) {
      assert (!file.exists(), s"Shuffle file $file was not cleaned up")
    }
  }
} 
Example 82
Source File: KryoSerializerDistributedSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import com.esotericsoftware.kryo.Kryo

import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.serializer.KryoDistributedTest._
import org.apache.spark.util.Utils

class KryoSerializerDistributedSuite extends SparkFunSuite with LocalSparkContext {

  test("kryo objects are serialised consistently in different processes") {
    val conf = new SparkConf(false)
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
      .set(config.MAX_TASK_FAILURES, 1)
      .set(config.BLACKLIST_ENABLED, false)

    val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
    conf.setJars(List(jar.getPath))

    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
    val original = Thread.currentThread.getContextClassLoader
    val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
    SparkEnv.get.serializer.setDefaultClassLoader(loader)

    val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()

    // Randomly mix the keys so that the join below will require a shuffle with each partition
    // sending data to multiple other partitions.
    val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)}

    // Join the two RDDs, and force evaluation
    assert(shuffledRDD.join(cachedRDD).collect().size == 1)
  }
}

object KryoDistributedTest {
  class MyCustomClass

  class AppJarRegistrator extends KryoRegistrator {
    override def registerClasses(k: Kryo) {
      val classLoader = Thread.currentThread.getContextClassLoader
      // scalastyle:off classforname
      k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
      // scalastyle:on classforname
    }
  }

  object AppJarRegistrator {
    val customClassName = "KryoSerializerDistributedSuiteCustomClass"
  }
} 
Example 83
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Seconds, Span}

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils


class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  override def beforeAll(): Unit = {
    super.beforeAll()
    val conf = new SparkConf()
      .set("spark.hadoop.outputCommitCoordination.enabled", "true")
      .set("spark.hadoop.mapred.output.committer.class",
        classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName)
    sc = new SparkContext("local[2, 4]", "test", conf)
  }

  test("exception thrown in OutputCommitter.commitTask()") {
    // Regression test for SPARK-10381
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {
        Utils.deleteRecursively(tempDir)
      }
    }
  }
}

private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      throw new java.io.FileNotFoundException("Intentional exception")
    }
    super.commitTask(context)
  }
} 
Example 84
Source File: DriverSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.io.File

import org.scalatest.concurrent.Timeouts
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.time.SpanSugar._

import org.apache.spark.util.Utils

class DriverSuite extends SparkFunSuite with Timeouts {

  ignore("driver should exit after finishing without cleanup (SPARK-530)") {
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val masters = Table("master", "local", "local-cluster[2,1,1024]")
    forAll(masters) { (master: String) =>
      val process = Utils.executeCommand(
        Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master),
        new File(sparkHome),
        Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
      failAfter(60 seconds) { process.waitFor() }
      // Ensure we still kill the process in case it timed out
      process.destroy()
    }
  }
}


object DriverWithoutCleanup {
  def main(args: Array[String]) {
    Utils.configTestLog4j("INFO")
    val conf = new SparkConf
    val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf)
    sc.parallelize(1 to 100, 4).count()
  }
} 
Example 85
Source File: DiskBlockManagerSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    try {
      Utils.deleteRecursively(rootDir0)
      Utils.deleteRecursively(rootDir1)
    } finally {
      super.afterAll()
    }
  }

  override def beforeEach() {
    super.beforeEach()
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
  }

  override def afterEach() {
    try {
      diskBlockManager.stop()
    } finally {
      super.afterEach()
    }
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
Example 86
Source File: DiskStoreSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.io.ChunkedByteBuffer
import org.apache.spark.util.Utils

class DiskStoreSuite extends SparkFunSuite {

  test("reads of memory-mapped and non memory-mapped files are equivalent") {
    // Re-opening the file store fails on Windows if the memory-mapped byte buffer
    // backed by the file has not yet been garbage-collected.
    assume(!Utils.isWindows)
    val confKey = "spark.storage.memoryMapThreshold"

    // Create a non-trivial (not all zeros) byte array
    val bytes = Array.tabulate[Byte](1000)(_.toByte)
    val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes))

    val blockId = BlockId("rdd_1_2")
    val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true)

    val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager)
    diskStoreMapped.putBytes(blockId, byteBuffer)
    val mapped = diskStoreMapped.getBytes(blockId)
    assert(diskStoreMapped.remove(blockId))

    val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager)
    diskStoreNotMapped.putBytes(blockId, byteBuffer)
    val notMapped = diskStoreNotMapped.getBytes(blockId)

    // Not possible to do isInstanceOf due to visibility of HeapByteBuffer
    assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")),
      "Expected HeapByteBuffer for un-mapped read")
    assert(mapped.getChunks().forall(_.isInstanceOf[MappedByteBuffer]),
      "Expected MappedByteBuffer for mapped read")

    def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = {
      val array = new Array[Byte](in.remaining())
      in.get(array)
      array
    }

    assert(Arrays.equals(mapped.toArray, bytes))
    assert(Arrays.equals(notMapped.toArray, bytes))
  }
} 
Example 87
Source File: TopologyMapperSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileOutputStream}

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark._
import org.apache.spark.util.Utils

class TopologyMapperSuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfter
  with LocalSparkContext {

  test("File based Topology Mapper") {
    val numHosts = 100
    val numRacks = 4
    val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap
    val propsFile = createPropertiesFile(props)

    val sparkConf = (new SparkConf(false))
    sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath)
    val topologyMapper = new FileBasedTopologyMapper(sparkConf)

    props.foreach {case (host, topology) =>
      val obtainedTopology = topologyMapper.getTopologyForHost(host)
      assert(obtainedTopology.isDefined)
      assert(obtainedTopology.get === topology)
    }

    // we get None for hosts not in the file
    assert(topologyMapper.getTopologyForHost("host").isEmpty)

    cleanup(propsFile)
  }

  def createPropertiesFile(props: Map[String, String]): File = {
    val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile
    val fileOS = new FileOutputStream(testFile)
    props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)}
    fileOS.close
    testFile
  }

  def cleanup(testFile: File): Unit = {
    testFile.getParentFile.listFiles.filter { file =>
      file.getName.startsWith(testFile.getName)
    }.foreach { _.delete() }
  }

} 
Example 88
Source File: LocalDirsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.{SparkConfWithEnv, Utils}


class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter {

  before {
    Utils.clearLocalRootDirs()
  }

  test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
    // Regression test for SPARK-2974
    assert(!new File("/NONEXISTENT_DIR").exists())
    val conf = new SparkConf(false)
      .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

  test("SPARK_LOCAL_DIRS override also affects driver") {
    // Regression test for SPARK-2975
    assert(!new File("/NONEXISTENT_DIR").exists())
    // spark.local.dir only contains invalid directories, but that's not a problem since
    // SPARK_LOCAL_DIRS will override it on both the driver and workers:
    val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir")))
      .set("spark.local.dir", "/NONEXISTENT_PATH")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

} 
Example 89
Source File: JdbcRDDSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.sql._

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils

class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  before {
    Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver")
    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
    try {

      try {
        val create = conn.createStatement
        create.execute("""
          CREATE TABLE FOO(
            ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
            DATA INTEGER
          )""")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
        (1 to 100).foreach { i =>
          insert.setInt(1, i * 2)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

      try {
        val create = conn.createStatement
        create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)")
        (1 to 100).foreach { i =>
          insert.setLong(1, 100000000000000000L +  4000000000000000L * i)
          insert.setInt(2, i)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

    } finally {
      conn.close()
    }
  }

  test("basic functionality") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      (r: ResultSet) => { r.getInt(1) } ).cache()

    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 10100)
  }

  test("large id overflow") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?",
      1131544775L, 567279358897692673L, 20,
      (r: ResultSet) => { r.getInt(1) } ).cache()
    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 5050)
  }

  after {
    try {
      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
      case se: SQLException if se.getSQLState == "08006" =>
        // Normal single database shutdown
        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
} 
Example 90
Source File: SparkFunSuite.scala    From spark-alchemy   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

// scalastyle:off
import java.io.File

import scala.annotation.tailrec
import org.apache.log4j.{Appender, Level, Logger}
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, Outcome, Suite}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.util.{AccumulatorContext, Utils}


// Base test class shared by Spark suites. The original also mixes in accumulator and
// thread-audit cleanup between tests, which is omitted in this excerpt.
abstract class SparkFunSuite
  extends FunSuite
  with BeforeAndAfterAll
  with BeforeAndAfterEach
  with Logging {

  protected def withLogAppender(
    appender: Appender,
    loggerName: Option[String] = None,
    level: Option[Level] = None)(
    f: => Unit): Unit = {
    val logger = loggerName.map(Logger.getLogger).getOrElse(Logger.getRootLogger)
    val restoreLevel = logger.getLevel
    logger.addAppender(appender)
    if (level.isDefined) {
      logger.setLevel(level.get)
    }
    try f finally {
      logger.removeAppender(appender)
      if (level.isDefined) {
        logger.setLevel(restoreLevel)
      }
    }
  }
} 
Example 91
Source File: HBasePartitioner.scala    From Backup-Repo   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hbase

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{CollectionsUtils, Utils}
import org.apache.spark.{Partitioner, SparkEnv}

object HBasePartitioner {
  implicit object HBaseRawOrdering extends Ordering[HBaseRawType] {
    def compare(a: HBaseRawType, b: HBaseRawType) = Bytes.compareTo(a, b)
  }
}

class HBasePartitioner (var splitKeys: Array[HBaseRawType]) extends Partitioner {
  import HBasePartitioner.HBaseRawOrdering

  type t = HBaseRawType

  lazy private val len = splitKeys.length

  // For a pre-split table, splitKeys(0) is bytes[0]; drop it, otherwise partition 0 would
  // always be empty and we would miss the last region's data during bulk load.
  lazy private val realSplitKeys = if (splitKeys.isEmpty) splitKeys else splitKeys.tail

  def numPartitions = if (len == 0) 1 else len

  @transient private val binarySearch: ((Array[t], t) => Int) = CollectionsUtils.makeBinarySearch[t]

  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[t]
    var partition = 0
    if (len <= 128 && len > 0) {
      // If we have less than 128 partitions naive search
      val ordering = implicitly[Ordering[t]]
      while (partition < realSplitKeys.length && ordering.gt(k, realSplitKeys(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(realSplitKeys, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition - 1
      }
      if (partition > realSplitKeys.length) {
        partition = realSplitKeys.length
      }
    }
    partition
  }

  override def equals(other: Any): Boolean = other match {
    case r: HBasePartitioner =>
      r.splitKeys.sameElements(splitKeys)
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < splitKeys.length) {
      result = prime * result + splitKeys(i).hashCode
      i += 1
    }
    result = prime * result
    result
  }
} 
Example 92
Source File: BytecodeUtils.scala    From graphx-algorithm   with GNU General Public License v2.0 5 votes vote down vote up
package org.apache.spark.graphx.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.HashSet
import scala.language.existentials

import org.apache.spark.util.Utils

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._



private[graphx] object BytecodeUtils {

  // Classes whose methods we never trace into while scanning bytecode (the other helpers
  // of the original object are omitted in this excerpt).
  private def skipClass(className: String): Boolean = {
    val c = className
    c.startsWith("scala/") || c.startsWith("java/") || c.startsWith("sun/")
  }

  private class MethodInvocationFinder(className: String, methodName: String)
    extends ClassVisitor(ASM4) {

    val methodsInvoked = new HashSet[(Class[_], String)]

    override def visitMethod(access: Int, name: String, desc: String,
                             sig: String, exceptions: Array[String]): MethodVisitor = {
      if (name == methodName) {
        new MethodVisitor(ASM4) {
          override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
            if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
              if (!skipClass(owner)) {
                methodsInvoked.add((Class.forName(owner.replace("/", ".")), name))
              }
            }
          }
        }
      } else {
        null
      }
    }
  }
} 
Example 93
Source File: MapJoinPartitionsRDD.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.util.Utils

class MapJoinPartitionsPartition(
    idx: Int,
    @transient private val rdd1: RDD[_],
    @transient private val rdd2: RDD[_],
    s2IdxArr: Array[Int]) extends Partition {

  var s1 = rdd1.partitions(idx)
  var s2Arr = s2IdxArr.map(s2Idx => rdd2.partitions(s2Idx))
  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    s1 = rdd1.partitions(idx)
    s2Arr = s2IdxArr.map(s2Idx => rdd2.partitions(s2Idx))
    oos.defaultWriteObject()
  }
}

class MapJoinPartitionsRDD[A: ClassTag, B: ClassTag, V: ClassTag](
    sc: SparkContext,
    var idxF: (Int) => Array[Int],
    var f: (Int, Iterator[A], Array[(Int, Iterator[B])]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B])
  extends RDD[V](sc, Nil) {

  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](rdd1.partitions.length)
    for (s1 <- rdd1.partitions) {
      val idx = s1.index
      array(idx) = new MapJoinPartitionsPartition(idx, rdd1, rdd2, idxF(idx))
    }
    array
  }

  override def getDependencies: Seq[Dependency[_]] = List(
    new OneToOneDependency(rdd1),
    new NarrowDependency(rdd2) {
      override def getParents(partitionId: Int): Seq[Int] = {
        idxF(partitionId)
      }
    }
  )

  override def getPreferredLocations(s: Partition): Seq[String] = {
    val fp = firstParent[A]
    // println(s"pref loc: ${fp.preferredLocations(fp.partitions(s.index))}")
    fp.preferredLocations(fp.partitions(s.index))
  }

  override def compute(split: Partition, context: TaskContext): Iterator[V] = {
    val currSplit = split.asInstanceOf[MapJoinPartitionsPartition]
    f(currSplit.s1.index, rdd1.iterator(currSplit.s1, context),
      currSplit.s2Arr.map(s2 => (s2.index, rdd2.iterator(s2, context)))
    )
  }

  override def clearDependencies() {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    idxF = null
    f = null
  }
} 
Example 94
Source File: MapJoinPartitionsRDDV2.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import org.apache.spark.serializer.Serializer
import org.apache.spark.{TaskContext, _}
import org.apache.spark.util.Utils

import scala.reflect.ClassTag

class MapJoinPartitionsPartitionV2(
    idx: Int,
    @transient private val rdd1: RDD[_],
    @transient private val rdd2: RDD[_],
    s2IdxArr: Array[Int]) extends Partition {

  var s1 = rdd1.partitions(idx)
  var s2Arr = s2IdxArr.map(s2Idx => rdd2.partitions(s2Idx))
  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    s1 = rdd1.partitions(idx)
    s2Arr = s2IdxArr.map(s2Idx => rdd2.partitions(s2Idx))
    oos.defaultWriteObject()
  }
}

class MapJoinPartitionsRDDV2[A: ClassTag, B: ClassTag, V: ClassTag](
    sc: SparkContext,
    var idxF: (Int) => Array[Int],
    var f: (Int, Iterator[A], Array[(Int, Iterator[B])]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    preservesPartitioning: Boolean = false)
  extends RDD[V](sc, Nil) {

  var rdd2WithPid = rdd2.mapPartitionsWithIndex((pid, iter) => iter.map(x => (pid, x)))

  private val serializer: Serializer = SparkEnv.get.serializer

  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](rdd1.partitions.length)
    for (s1 <- rdd1.partitions) {
      val idx = s1.index
      array(idx) = new MapJoinPartitionsPartitionV2(idx, rdd1, rdd2, idxF(idx))
    }
    array
  }

  override def getDependencies: Seq[Dependency[_]] = List(
    new OneToOneDependency(rdd1),
    new ShuffleDependency[Int, B, B](
      rdd2WithPid.asInstanceOf[RDD[_ <: Product2[Int, B]]],
      new IdentityPartitioner(rdd2WithPid.getNumPartitions), serializer)
  )

  override def getPreferredLocations(s: Partition): Seq[String] = {
    val fp = firstParent[A]
    // println(s"pref loc: ${fp.preferredLocations(fp.partitions(s.index))}")
    fp.preferredLocations(fp.partitions(s.index))
  }

  override def compute(split: Partition, context: TaskContext): Iterator[V] = {
    val currSplit = split.asInstanceOf[MapJoinPartitionsPartitionV2]
    val rdd2Dep = dependencies(1).asInstanceOf[ShuffleDependency[Int, Any, Any]]
    val rdd2PartIter = currSplit.s2Arr.map(s2 => (s2.index,
      SparkEnv.get.shuffleManager
        .getReader[Int, B](rdd2Dep.shuffleHandle, s2.index, s2.index + 1, context)
        .read().map(x => x._2)
      ))
    val rdd1Iter = rdd1.iterator(currSplit.s1, context)
    f(currSplit.s1.index, rdd1Iter, rdd2PartIter)
  }

  override def clearDependencies() {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    rdd2WithPid = null
    idxF = null
    f = null
  }
}

private[spark] class IdentityPartitioner(val numParts: Int) extends Partitioner {
  require(numPartitions > 0)
  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
  override def numPartitions: Int = numParts
} 
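The writeObject override in MapJoinPartitionsPartitionV2 relies on Utils.tryOrIOException, which re-throws any failure as an IOException so Java serialization reports it cleanly. A minimal sketch of the same pattern, assuming a hypothetical package under org.apache.spark so the private[spark] Utils object is visible:

package org.apache.spark.demo

import java.io.{IOException, ObjectOutputStream}

import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

// A partition that re-resolves its parent split just before serialization,
// wrapping the work in Utils.tryOrIOException as the example above does.
class ParentTrackingPartition(idx: Int, @transient private val parent: RDD[_]) extends Partition {
  var parentSplit: Partition = parent.partitions(idx)
  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Any exception thrown here surfaces to the serializer as an IOException.
    parentSplit = parent.partitions(idx)
    oos.defaultWriteObject()
  }
}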
Example 95
Source File: QueryPartitionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import java.io.File
import java.sql.Timestamp

import com.google.common.io.Files
import org.apache.hadoop.fs.FileSystem

import org.apache.spark.internal.config._
import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils

class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  import spark.implicits._

  private def queryWhenPathNotExist(): Unit = {
    withTempView("testData") {
      withTable("table_with_partition", "createAndInsertTest") {
        withTempDir { tmpDir =>
          val testData = sparkContext.parallelize(
            (1 to 10).map(i => TestData(i, i.toString))).toDF()
          testData.createOrReplaceTempView("testData")

          // create the table for test
          sql(s"CREATE TABLE table_with_partition(key int,value string) " +
              s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='1') " +
              "SELECT key,value FROM testData")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='2') " +
              "SELECT key,value FROM testData")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='3') " +
              "SELECT key,value FROM testData")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='4') " +
              "SELECT key,value FROM testData")

          // test for the exist path
          checkAnswer(sql("select key,value from table_with_partition"),
            testData.union(testData).union(testData).union(testData))

          // delete the path of one partition
          tmpDir.listFiles
              .find { f => f.isDirectory && f.getName().startsWith("ds=") }
              .foreach { f => Utils.deleteRecursively(f) }

          // test for after delete the path
          checkAnswer(sql("select key,value from table_with_partition"),
            testData.union(testData).union(testData))
        }
      }
    }
  }

  test("SPARK-5068: query data when path doesn't exist") {
    withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "true") {
      queryWhenPathNotExist()
    }
  }

  test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") {
    withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "false") {
      sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true")
      queryWhenPathNotExist()
    }
  }

  test("SPARK-21739: Cast expression should initialize timezoneId") {
    withTable("table_with_timestamp_partition") {
      sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
      sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
        "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")

      // test for Cast expression in TableReader
      checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
        Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))

      // test for Cast expression in HiveTableScanExec
      checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
        "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
    }
  }
} 
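The suite above leans on Utils.createTempDir and Utils.deleteRecursively for scratch directories. A minimal sketch of that clean-up pattern, again assuming a hypothetical package under org.apache.spark (Utils is private[spark]):

package org.apache.spark.demo

import java.io.File

import org.apache.spark.util.Utils

object TempDirScratch {
  def withTempDir[T](body: File => T): T = {
    // createTempDir also registers a shutdown-hook deletion, but deleting
    // eagerly keeps repeated test runs from piling up directories.
    val dir = Utils.createTempDir()
    try body(dir) finally Utils.deleteRecursively(dir)
  }
}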
Example 96
Source File: HiveClientBuilder.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import java.io.File

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.util.VersionInfo

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

private[client] object HiveClientBuilder {
  // In order to speed up test execution during development or in Jenkins, you can specify the path
  // of an existing Ivy cache:
  private val ivyPath: Option[String] = {
    sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse(
      Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath))
  }

  private def buildConf(extraConf: Map[String, String]) = {
    lazy val warehousePath = Utils.createTempDir()
    lazy val metastorePath = Utils.createTempDir()
    metastorePath.delete()
    extraConf ++ Map(
      "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true",
      "hive.metastore.warehouse.dir" -> warehousePath.toString)
  }

  // for testing only
  def buildClient(
      version: String,
      hadoopConf: Configuration,
      extraConf: Map[String, String] = Map.empty,
      sharesHadoopClasses: Boolean = true): HiveClient = {
    IsolatedClientLoader.forVersion(
      hiveMetastoreVersion = version,
      hadoopVersion = VersionInfo.getVersion,
      sparkConf = new SparkConf(),
      hadoopConf = hadoopConf,
      config = buildConf(extraConf),
      ivyPath = ivyPath,
      sharesHadoopClasses = sharesHadoopClasses).createClient()
  }
} 
Example 97
Source File: SparkSQLEnv.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.io.PrintStream

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils


  def stop() {
    logDebug("Shutting down Spark SQL Environment")
    // Stop the SparkContext
    if (SparkSQLEnv.sparkContext != null) {
      sparkContext.stop()
      sparkContext = null
      sqlContext = null
    }
  }
} 
Example 98
Source File: HiveMetastoreLazyInitializationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.util.Utils

class HiveMetastoreLazyInitializationSuite extends SparkFunSuite {

  test("lazily initialize Hive client") {
    val spark = SparkSession.builder()
      .appName("HiveMetastoreLazyInitializationSuite")
      .master("local[2]")
      .enableHiveSupport()
      .config("spark.hadoop.hive.metastore.uris", "thrift://127.0.0.1:11111")
      .getOrCreate()
    val originalLevel = org.apache.log4j.Logger.getRootLogger().getLevel
    try {
      // Avoid outputting a lot of expected warning logs
      spark.sparkContext.setLogLevel("error")

      // We should be able to run Spark jobs without Hive client.
      assert(spark.sparkContext.range(0, 1).count() === 1)

      // We should be able to use Spark SQL if no table references.
      assert(spark.sql("select 1 + 1").count() === 1)
      assert(spark.range(0, 1).count() === 1)

      // We should be able to use fs
      val path = Utils.createTempDir()
      path.delete()
      try {
        spark.range(0, 1).write.parquet(path.getAbsolutePath)
        assert(spark.read.parquet(path.getAbsolutePath).count() === 1)
      } finally {
        Utils.deleteRecursively(path)
      }

      // Make sure that we are not using the local derby metastore.
      val exceptionString = Utils.exceptionString(intercept[AnalysisException] {
        spark.sql("show tables")
      })
      for (msg <- Seq(
        "show tables",
        "Could not connect to meta store",
        "org.apache.thrift.transport.TTransportException",
        "Connection refused")) {
        exceptionString.contains(msg)
      }
    } finally {
      spark.sparkContext.setLogLevel(originalLevel.toString)
      spark.stop()
    }
  }
} 
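Utils.exceptionString, used above to assert on the metastore failure, flattens a Throwable, its stack trace, and its chained causes into a single string. A small hedged sketch of wrapping arbitrary work that way:

package org.apache.spark.demo

import org.apache.spark.util.Utils

object ExceptionStrings {
  def describeFailure(work: => Unit): Option[String] =
    try { work; None } catch {
      // exceptionString includes the stack trace and chained causes, which is
      // what makes substring assertions like the suite above possible.
      case e: Exception => Some(Utils.exceptionString(e))
    }
}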
Example 99
Source File: JdbcConnectionUriSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.sql.DriverManager

import org.apache.hive.jdbc.HiveDriver

import org.apache.spark.util.Utils

class JdbcConnectionUriSuite extends HiveThriftServer2Test {
  Utils.classForName(classOf[HiveDriver].getCanonicalName)

  override def mode: ServerMode.Value = ServerMode.binary

  val JDBC_TEST_DATABASE = "jdbc_test_database"
  val USER = System.getProperty("user.name")
  val PASSWORD = ""

  override protected def beforeAll(): Unit = {
    super.beforeAll()

    val jdbcUri = s"jdbc:hive2://localhost:$serverPort/"
    val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD)
    val statement = connection.createStatement()
    statement.execute(s"CREATE DATABASE $JDBC_TEST_DATABASE")
    connection.close()
  }

  override protected def afterAll(): Unit = {
    try {
      val jdbcUri = s"jdbc:hive2://localhost:$serverPort/"
      val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD)
      val statement = connection.createStatement()
      statement.execute(s"DROP DATABASE $JDBC_TEST_DATABASE")
      connection.close()
    } finally {
      super.afterAll()
    }
  }

  test("SPARK-17819 Support default database in connection URIs") {
    val jdbcUri = s"jdbc:hive2://localhost:$serverPort/$JDBC_TEST_DATABASE"
    val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD)
    val statement = connection.createStatement()
    try {
      val resultSet = statement.executeQuery("select current_database()")
      resultSet.next()
      assert(resultSet.getString(1) === JDBC_TEST_DATABASE)
    } finally {
      statement.close()
      connection.close()
    }
  }
} 
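The call to Utils.classForName(classOf[HiveDriver].getCanonicalName) above forces the JDBC driver class to load, and thus register itself with DriverManager, before any connection is opened. A minimal sketch with a hypothetical driver class name:

package org.apache.spark.demo

import org.apache.spark.util.Utils

object DriverPreload {
  // Loading through Utils.classForName uses the context / Spark class loader,
  // so drivers shipped via --jars are found as well.
  def preload(driverClassName: String): Unit = Utils.classForName(driverClassName)
}

// e.g. DriverPreload.preload("org.h2.Driver")  // assumes the driver is on the classpath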
Example 100
Source File: DataSourceManagerFactory.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.util.Utils

object DataSourceManagerFactory {

  def create(
      datasourceType: String,
      conf: SparkConf,
      hadoopConf: Configuration): DataSourceManager = {
    val loader = Utils.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[DataSourceManager], loader)
    var cls: Class[_] = null
    // Because ServiceLoader is used to discover any user-provided DataSourceManager here,
    // META-INF/services/org.apache.spark.sql.sources.DataSourceRegister must be packaged properly
    // in the user's jar, and the implementation of DataSourceManager must have a public
    // parameterless constructor. In Scala, def this() = this(null, ...) satisfies this.
    try {
      cls = serviceLoader.asScala
        .filter(_.shortName().equals(datasourceType))
        .toList match {
        case head :: Nil =>
          head.getClass
        case _ =>
          throw new SparkException(s"error when instantiate datasource ${datasourceType}")
      }
    } catch {
      case _: Exception =>
        throw new SparkException(
          s"""Can't find corresponding DataSourceManager for ${datasourceType} type,
             |please check
             |1. META-INF/services/org.apache.spark.sql.sources.DataSourceRegister is packaged
             |2. your implementation of DataSourceManager's shortname is ${datasourceType}
             |3. your implementation of DataSourceManager must have a public parameterless
             |   constructor. For scala language, def this() = this(null, null, ...) just work.
           """.stripMargin)
    }
    try {
      val constructor = cls.getConstructor(classOf[SparkConf], classOf[Configuration])
      val newHadoopConf = new Configuration(hadoopConf)
      constructor.newInstance(conf, newHadoopConf).asInstanceOf[DataSourceManager]
    } catch {
      case _: NoSuchMethodException =>
        try {
          cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[DataSourceManager]
        } catch {
          case _: NoSuchMethodException =>
            cls.getConstructor().newInstance().asInstanceOf[DataSourceManager]
        }
    }
  }
} 
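The factory above is a ServiceLoader lookup driven by Utils.getContextOrSparkClassLoader. A stripped-down sketch of the same discovery step, with a hypothetical Plugin trait standing in for DataSourceManager:

package org.apache.spark.demo

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
import org.apache.spark.util.Utils

trait Plugin { def shortName(): String }   // hypothetical SPI, mirrors DataSourceManager

object PluginLookup {
  def find(name: String): Plugin = {
    // Prefer the thread context class loader, falling back to Spark's own loader.
    val loader = Utils.getContextOrSparkClassLoader
    ServiceLoader.load(classOf[Plugin], loader).asScala
      .find(_.shortName() == name)
      .getOrElse(throw new SparkException(s"no plugin registered for $name"))
  }
}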
Example 101
Source File: UDTRegistration.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.types

import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


  def getUDTFor(userClass: String): Option[Class[_]] = {
    udtMap.get(userClass).map { udtClassName =>
      if (Utils.classIsLoadable(udtClassName)) {
        val udtClass = Utils.classForName(udtClassName)
        if (classOf[UserDefinedType[_]].isAssignableFrom(udtClass)) {
          udtClass
        } else {
          throw new SparkException(
            s"${udtClass.getName} is not an UserDefinedType. Please make sure registering " +
              s"an UserDefinedType for ${userClass}")
        }
      } else {
        throw new SparkException(
          s"Can not load in UserDefinedType ${udtClassName} for user class ${userClass}.")
      }
    }
  }
} 
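getUDTFor checks Utils.classIsLoadable before calling Utils.classForName, so a stale registration produces a clear error instead of a raw ClassNotFoundException. A small sketch of the same guard:

package org.apache.spark.demo

import org.apache.spark.SparkException
import org.apache.spark.util.Utils

object SafeReflection {
  def loadOrExplain(className: String): Class[_] =
    if (Utils.classIsLoadable(className)) {
      Utils.classForName(className)
    } else {
      // Same idea as UDTRegistration: fail with a message naming the class.
      throw new SparkException(s"Cannot load class $className on this classpath.")
    }
}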
Example 102
Source File: randomExpressions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom


// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""",
  examples = """
    Examples:
      > SELECT _FUNC_();
       -0.3254147983080288
      > SELECT _FUNC_(0);
       1.1164209726833079
      > SELECT _FUNC_(null);
       1.1164209726833079
  """,
  note = "The function is non-deterministic in general case.")
// scalastyle:on line.size.limit
case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {

  def this() = this(Literal(Utils.random.nextLong(), LongType))

  override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))

  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = classOf[XORShiftRandom].getName
    val rngTerm = ctx.addMutableState(className, "rng")
    ctx.addPartitionInitializationStatement(
      s"$rngTerm = new $className(${seed}L + partitionIndex);")
    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
      isNull = FalseLiteral)
  }

  override def freshCopy(): Randn = Randn(child)
}

object Randn {
  def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
} 
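Randn seeds itself with Utils.random.nextLong() when no seed is supplied, so every new expression instance gets an independent default seed, and doGenCode then offsets that seed by the partition index. A hedged sketch of the same defaulting pattern (SeededNoise is a made-up class, not part of Spark):

package org.apache.spark.demo

import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

// Mirrors the Randn constructor: default the seed from Utils.random,
// then derive a per-partition generator the way doGenCode does.
class SeededNoise(val seed: Long = Utils.random.nextLong()) {
  def forPartition(partitionIndex: Int): XORShiftRandom =
    new XORShiftRandom(seed + partitionIndex)
}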
Example 103
Source File: CodeGeneratorWithInterpretedFallback.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils


abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging {

  def createObject(in: IN): OUT = {
    // We are allowed to choose codegen-only or no-codegen modes if under tests.
    val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE)
    val fallbackMode = CodegenObjectFactoryMode.withName(config)

    fallbackMode match {
      case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting =>
        createCodeGeneratedObject(in)
      case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting =>
        createInterpretedObject(in)
      case _ =>
        try {
          createCodeGeneratedObject(in)
        } catch {
          case NonFatal(_) =>
            // We should have already seen the error message in `CodeGenerator`
            logWarning("Expr codegen error and falling back to interpreter mode")
            createInterpretedObject(in)
        }
    }
  }

  protected def createCodeGeneratedObject(in: IN): OUT
  protected def createInterpretedObject(in: IN): OUT
} 
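CodeGeneratorWithInterpretedFallback only honors the forced codegen-only / no-codegen modes when Utils.isTesting is true, i.e. when spark.testing or SPARK_TESTING is set. A tiny sketch of gating a debug-only path the same way:

package org.apache.spark.demo

import org.apache.spark.util.Utils

object TestOnlyChecks {
  // Extra (possibly expensive) validation that should never run in production.
  def verifyInvariant(condition: => Boolean, message: String): Unit =
    if (Utils.isTesting) {
      assert(condition, message)
    }
}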
Example 104
Source File: package.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.reflect.internal.util.AbstractFileClassLoader

import org.apache.spark.sql.catalyst.rules
import org.apache.spark.util.Utils


  object DumpByteCode {
    import scala.sys.process._
    val dumpDirectory = Utils.createTempDir()
    dumpDirectory.mkdir()

    def apply(obj: Any): Unit = {
      val generatedClass = obj.getClass
      val classLoader =
        generatedClass
          .getClassLoader
          .asInstanceOf[AbstractFileClassLoader]
      val generatedBytes = classLoader.classBytes(generatedClass.getName)

      val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName)
      if (!packageDir.exists()) { packageDir.mkdir() }

      val classFile =
        new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class")

      val outfile = new java.io.FileOutputStream(classFile)
      outfile.write(generatedBytes)
      outfile.close()

      // scalastyle:off println
      println(
        s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!)
      // scalastyle:on println
    }
  }
} 
Example 105
Source File: OuterScopes.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.encoders

import java.util.concurrent.ConcurrentMap

import com.google.common.collect.MapMaker

import org.apache.spark.util.Utils

object OuterScopes {
  @transient
  lazy val outerScopes: ConcurrentMap[String, AnyRef] =
    new MapMaker().weakValues().makeMap()

  
  def getOuterScope(innerCls: Class[_]): () => AnyRef = {
    assert(innerCls.isMemberClass)
    val outerClassName = innerCls.getDeclaringClass.getName
    val outer = outerScopes.get(outerClassName)
    if (outer == null) {
      outerClassName match {
        // If the outer class is generated by REPL, users don't need to register it as it has
        // only one instance and there is a way to retrieve it: get the `$read` object, call the
        // `INSTANCE()` method to get the single instance of class `$read`. Then call `$iw()`
        // method multiply times to get the single instance of the inner most `$iw` class.
        case REPLClass(baseClassName) =>
          () => {
            val objClass = Utils.classForName(baseClassName + "$")
            val objInstance = objClass.getField("MODULE$").get(null)
            val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
            val baseClass = Utils.classForName(baseClassName)

            var getter = iwGetter(baseClass)
            var obj = baseInstance
            while (getter != null) {
              obj = getter.invoke(obj)
              getter = iwGetter(getter.getReturnType)
            }

            if (obj == null) {
              throw new RuntimeException(s"Failed to get outer pointer for ${innerCls.getName}")
            }

            outerScopes.putIfAbsent(outerClassName, obj)
            obj
          }
        case _ => null
      }
    } else {
      () => outer
    }
  }

  private def iwGetter(cls: Class[_]) = {
    try {
      cls.getMethod("$iw")
    } catch {
      case _: NoSuchMethodException => null
    }
  }

  // The format of REPL generated wrapper class's name, e.g. `$line12.$read$$iw$$iw`
  private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r
} 
Example 106
Source File: RuleExecutor.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.rules

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.util.Utils

object RuleExecutor {
  protected val queryExecutionMeter = QueryExecutionMetering()

  
  def execute(plan: TreeType): TreeType = {
    var curPlan = plan
    val queryExecutionMetrics = RuleExecutor.queryExecutionMeter

    batches.foreach { batch =>
      val batchStartPlan = curPlan
      var iteration = 1
      var lastPlan = curPlan
      var continue = true

      // Run until fix point (or the max number of iterations as specified in the strategy).
      while (continue) {
        curPlan = batch.rules.foldLeft(curPlan) {
          case (plan, rule) =>
            val startTime = System.nanoTime()
            val result = rule(plan)
            val runTime = System.nanoTime() - startTime

            if (!result.fastEquals(plan)) {
              queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName)
              queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime)
              logTrace(
                s"""
                  |=== Applying Rule ${rule.ruleName} ===
                  |${sideBySide(plan.treeString, result.treeString).mkString("\n")}
                """.stripMargin)
            }
            queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime)
            queryExecutionMetrics.incNumExecution(rule.ruleName)

            // Run the structural integrity checker against the plan after each rule.
            if (!isPlanIntegral(result)) {
              val message = s"After applying rule ${rule.ruleName} in batch ${batch.name}, " +
                "the structural integrity of the plan is broken."
              throw new TreeNodeException(result, message, null)
            }

            result
        }
        iteration += 1
        if (iteration > batch.strategy.maxIterations) {
          // Only log if this is a rule that is supposed to run more than once.
          if (iteration != 2) {
            val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}"
            if (Utils.isTesting) {
              throw new TreeNodeException(curPlan, message, null)
            } else {
              logWarning(message)
            }
          }
          continue = false
        }

        if (curPlan.fastEquals(lastPlan)) {
          logTrace(
            s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.")
          continue = false
        }
        lastPlan = curPlan
      }

      if (!batchStartPlan.fastEquals(curPlan)) {
        logDebug(
          s"""
            |=== Result of Batch ${batch.name} ===
            |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")}
          """.stripMargin)
      } else {
        logTrace(s"Batch ${batch.name} has no effect.")
      }
    }

    curPlan
  }
} 
Example 107
Source File: CompressionCodecs.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.util.Locale

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress._

import org.apache.spark.util.Utils

object CompressionCodecs {
  private val shortCompressionCodecNames = Map(
    "none" -> null,
    "uncompressed" -> null,
    "bzip2" -> classOf[BZip2Codec].getName,
    "deflate" -> classOf[DeflateCodec].getName,
    "gzip" -> classOf[GzipCodec].getName,
    "lz4" -> classOf[Lz4Codec].getName,
    "snappy" -> classOf[SnappyCodec].getName)

  
  def setCodecConfiguration(conf: Configuration, codec: String): Unit = {
    if (codec != null) {
      conf.set("mapreduce.output.fileoutputformat.compress", "true")
      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
      conf.set("mapreduce.map.output.compress", "true")
      conf.set("mapreduce.map.output.compress.codec", codec)
    } else {
      // This infers the option `compression` is set to `uncompressed` or `none`.
      conf.set("mapreduce.output.fileoutputformat.compress", "false")
      conf.set("mapreduce.map.output.compress", "false")
    }
  }
} 
Example 108
Source File: LogicalRelation.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.util.Utils


  override def newInstance(): LogicalRelation = {
    this.copy(output = output.map(_.newInstance()))
  }

  override def refresh(): Unit = relation match {
    case fs: HadoopFsRelation => fs.location.refresh()
    case _ =>  // Do nothing.
  }

  override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
}

object LogicalRelation {
  def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation =
    LogicalRelation(relation, relation.schema.toAttributes, None, isStreaming)

  def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation =
    LogicalRelation(relation, relation.schema.toAttributes, Some(table), false)
} 
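LogicalRelation's simpleString uses Utils.truncatedString, which joins a sequence but caps how many elements are rendered, collapsing the rest into "... N more fields". A hedged sketch; in newer Spark this helper moved out of Utils, so this assumes the 2.x layout these examples target:

package org.apache.spark.demo

import org.apache.spark.util.Utils

object PlanStrings {
  // Render a possibly huge attribute list without blowing up log lines.
  def describe(nodeName: String, columns: Seq[String]): String =
    s"$nodeName${Utils.truncatedString(columns, "[", ",", "]")}"
}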
Example 109
Source File: DataSourceV2StringFormat.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.v2

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.util.Utils


  def pushedFilters: Seq[Expression]

  private def sourceName: String = source match {
    case registered: DataSourceRegister => registered.shortName()
    // source.getClass.getSimpleName can cause Malformed class name error,
    // call safer `Utils.getSimpleName` instead
    case _ => Utils.getSimpleName(source.getClass)
  }

  def metadataString: String = {
    val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)]

    if (pushedFilters.nonEmpty) {
      entries += "Filters" -> pushedFilters.mkString("[", ", ", "]")
    }

    // TODO: we should only display some standard options like path, table, etc.
    if (options.nonEmpty) {
      entries += "Options" -> Utils.redact(options).map {
        case (k, v) => s"$k=$v"
      }.mkString("[", ",", "]")
    }

    val outputStr = Utils.truncatedString(output, "[", ", ", "]")

    val entriesStr = if (entries.nonEmpty) {
      Utils.truncatedString(entries.map {
        case (key, value) => key + ": " + StringUtils.abbreviate(value, 100)
      }, " (", ", ", ")")
    } else {
      ""
    }

    s"$sourceName$outputStr$entriesStr"
  }
} 
Example 110
Source File: DriverRegistry.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Driver, DriverManager}

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


  DriverManager.getDrivers

  private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty

  def register(className: String): Unit = {
    val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
    if (cls.getClassLoader == null) {
      logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required")
    } else if (wrapperMap.get(className).isDefined) {
      logTrace(s"Wrapper for $className already exists")
    } else {
      synchronized {
        if (wrapperMap.get(className).isEmpty) {
          val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
          DriverManager.registerDriver(wrapper)
          wrapperMap(className) = wrapper
          logTrace(s"Wrapper for $className registered")
        }
      }
    }
  }
} 
Example 111
Source File: EvalPythonExec.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import java.io.File

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.ChainedPythonFunctions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils



abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends SparkPlan {

  def children: Seq[SparkPlan] = child :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
        val (chained, children) = collectFunctions(u)
        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
      case children =>
        // There should not be any other UDFs, or the children can't be evaluated directly.
        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
    }
  }

  protected def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow]

  protected override def doExecute(): RDD[InternalRow] = {
    val inputRDD = child.execute().map(_.copy())

    inputRDD.mapPartitions { iter =>
      val context = TaskContext.get()

      // The queue used to buffer input rows so we can drain it to
      // combine input with output from Python.
      val queue = HybridRowQueue(context.taskMemoryManager(),
        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
      context.addTaskCompletionListener[Unit] { ctx =>
        queue.close()
      }

      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

      // flatten all the arguments
      val allInputs = new ArrayBuffer[Expression]
      val dataTypes = new ArrayBuffer[DataType]
      val argOffsets = inputs.map { input =>
        input.map { e =>
          if (allInputs.exists(_.semanticEquals(e))) {
            allInputs.indexWhere(_.semanticEquals(e))
          } else {
            allInputs += e
            dataTypes += e.dataType
            allInputs.length - 1
          }
        }.toArray
      }.toArray
      val projection = newMutableProjection(allInputs, child.output)
      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
        StructField(s"_$i", dt)
      })

      // Add rows to queue to join later with the result.
      val projectedRowIter = iter.map { inputRow =>
        queue.add(inputRow.asInstanceOf[UnsafeRow])
        projection(inputRow)
      }

      val outputRowIterator = evaluate(
        pyFuncs, argOffsets, projectedRowIter, schema, context)

      val joined = new JoinedRow
      val resultProj = UnsafeProjection.create(output, output)

      outputRowIterator.map { outputRow =>
        resultProj(joined(queue.remove(), outputRow))
      }
    }
  }
} 
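EvalPythonExec places its HybridRowQueue spill files under Utils.getLocalDir(SparkEnv.get.conf), which picks one of the configured local scratch directories. A small sketch of reserving a per-task scratch file the same way (ScratchFiles and its method are made up for illustration):

package org.apache.spark.demo

import java.io.File

import org.apache.spark.SparkEnv
import org.apache.spark.util.Utils

object ScratchFiles {
  // Must be called where SparkEnv is initialized (driver or executor).
  def newScratchFile(prefix: String): File = {
    val localDir = new File(Utils.getLocalDir(SparkEnv.get.conf))
    File.createTempFile(prefix, ".tmp", localDir)
  }
}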
Example 112
Source File: ExistingRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Encoder, Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

object RDDConversions {
  def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
    data.mapPartitions { iterator =>
      val numColumns = outputTypes.length
      val mutableRow = new GenericInternalRow(numColumns)
      val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
      iterator.map { r =>
        var i = 0
        while (i < numColumns) {
          mutableRow(i) = converters(i)(r.productElement(i))
          i += 1
        }

        mutableRow
      }
    }
  }
}

case class RDDScanExec(
    output: Seq[Attribute],
    rdd: RDD[InternalRow],
    name: String,
    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {

  private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("")

  override val nodeName: String = s"Scan $name$rddName"

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
      val proj = UnsafeProjection.create(schema)
      proj.initialize(index)
      iter.map { r =>
        numOutputRows += 1
        proj(r)
      }
    }
  }

  override def simpleString: String = {
    s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}"
  }
} 
Example 113
Source File: FileStreamOptions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import scala.util.Try

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.util.Utils


  val fileNameOnly: Boolean = withBooleanParameter("fileNameOnly", false)

  private def withBooleanParameter(name: String, default: Boolean) = {
    parameters.get(name).map { str =>
      try {
        str.toBoolean
      } catch {
        case _: IllegalArgumentException =>
          throw new IllegalArgumentException(
            s"Invalid value '$str' for option '$name', must be 'true' or 'false'")
      }
    }.getOrElse(default)
  }
} 
Example 114
Source File: ContinuousWriteRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
import org.apache.spark.util.Utils


class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow])
    extends RDD[Unit](prev) {

  override val partitioner = prev.partitioner

  override def getPartitions: Array[Partition] = prev.partitions

  override def compute(split: Partition, context: TaskContext): Iterator[Unit] = {
    val epochCoordinator = EpochCoordinatorRef.get(
      context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
      SparkEnv.get)
    EpochTracker.initializeCurrentEpoch(
      context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
    while (!context.isInterrupted() && !context.isCompleted()) {
      var dataWriter: DataWriter[InternalRow] = null
      // write the data and commit this writer.
      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
        try {
          val dataIterator = prev.compute(split, context)
          dataWriter = writeTask.createDataWriter(
            context.partitionId(),
            context.taskAttemptId(),
            EpochTracker.getCurrentEpoch.get)
          while (dataIterator.hasNext) {
            dataWriter.write(dataIterator.next())
          }
          logInfo(s"Writer for partition ${context.partitionId()} " +
            s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
          val msg = dataWriter.commit()
          epochCoordinator.send(
            CommitPartitionEpoch(
              context.partitionId(),
              EpochTracker.getCurrentEpoch.get,
              msg)
          )
          logInfo(s"Writer for partition ${context.partitionId()} " +
            s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.")
          EpochTracker.incrementCurrentEpoch()
        } catch {
          case _: InterruptedException =>
          // Continuous shutdown always involves an interrupt. Just finish the task.
        }
      })(catchBlock = {
        // If there is an error, abort this writer. We enter this callback in the middle of
        // rethrowing an exception, so compute() will stop executing at this point.
        logError(s"Writer for partition ${context.partitionId()} is aborting.")
        if (dataWriter != null) dataWriter.abort()
        logError(s"Writer for partition ${context.partitionId()} aborted.")
      })
    }

    Iterator()
  }

  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }
} 
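The write loop above is wrapped in Utils.tryWithSafeFinallyAndFailureCallbacks, whose catchBlock runs only when the main block throws, while suppressing secondary exceptions so the original failure is what propagates. A condensed sketch of the commit/abort shape, with hypothetical commit() and abort() hooks:

package org.apache.spark.demo

import org.apache.spark.util.Utils

object CommitOrAbort {
  // `commit` and `abort` are hypothetical task hooks, not a Spark API.
  def run[T](body: => T)(commit: T => Unit)(abort: => Unit): Unit =
    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
      val result = body
      commit(result)        // only reached if body succeeded
    })(catchBlock = {
      abort                 // runs while the original exception is rethrown
    })
}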
Example 115
Source File: BenchmarkQueryTest.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with BeforeAndAfterAll {

  // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting
  // the max iteration of analyzer/optimizer batches.
  assert(Utils.isTesting, "spark.testing is not set to true")

  
  protected override def afterAll(): Unit = {
    try {
      // For debugging dump some statistics about how much time was spent in various optimizer rules
      logWarning(RuleExecutor.dumpTimeSpent())
      spark.sessionState.catalog.reset()
    } finally {
      super.afterAll()
    }
  }

  override def beforeAll() {
    super.beforeAll()
    RuleExecutor.resetMetrics()
  }

  protected def checkGeneratedCode(plan: SparkPlan): Unit = {
    val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]()
    plan foreach {
      case s: WholeStageCodegenExec =>
        codegenSubtrees += s
      case _ =>
    }
    codegenSubtrees.toSeq.foreach { subtree =>
      val code = subtree.doCodeGen()._2
      try {
        // Just check the generated code can be properly compiled
        CodeGenerator.compile(code)
      } catch {
        case e: Exception =>
          val msg =
            s"""
               |failed to compile:
               |Subtree:
               |$subtree
               |Generated code:
               |${CodeFormatter.format(code)}
             """.stripMargin
          throw new Exception(msg, e)
      }
    }
  }
} 
Example 116
Source File: RowDataSourceStrategySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE OR REPLACE TEMPORARY VIEW inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 117
Source File: AlarmFactory.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.alarm

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
import org.apache.spark.util.Utils

object AlarmFactory {
  def create(alarmName: String, options: Map[String, String]): Alarm = {
    val loader = Utils.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[Alarm], loader)
    val AlarmClass =
      serviceLoader.asScala.filter(_.name.equalsIgnoreCase(alarmName)).toList match {
        case head :: Nil =>
          head.getClass
        case _ =>
          throw new SparkException("error when instantiate spark.xsql.alarm.items")
      }
    AlarmClass.newInstance().bind(options)
  }

} 
Example 118
Source File: BarChartPainter.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.painter

import java.io.File
import java.util.Scanner

import org.jfree.chart.{ChartFactory, ChartUtils}
import org.jfree.chart.plot.PlotOrientation
import org.jfree.data.category.DefaultCategoryDataset

import org.apache.spark.util.Utils

class BarChartPainter(dataPath: String, picturePath: String)
  extends Painter(dataPath, picturePath) {

  def createDataset(): DefaultCategoryDataset = {
    fw.flush()
    fw.close()
    val dataset = new DefaultCategoryDataset
    val scaner = new Scanner(new File(dataPath))
    while (scaner.hasNext()) {
      val cols = scaner.next().split(",")
      dataset.addValue(Utils.byteStringAsMb(cols(1) + "b"), "peak", cols(0))
      dataset.addValue(Utils.byteStringAsMb(cols(2) + "b"), "majority", cols(0))
    }
    dataset
  }

  def paint(
      width: Int,
      height: Int,
      chartTitle: String,
      categoryAxisLabel: String,
      valueAxisLabel: String,
      yLB: Double,
      yUB: Double): Unit = {
    val barChart = ChartFactory.createBarChart(
      chartTitle,
      categoryAxisLabel,
      valueAxisLabel,
      createDataset,
      PlotOrientation.VERTICAL,
      true,
      false,
      false)
    barChart.getCategoryPlot.getRangeAxis.setRange(yLB, yUB)
    ChartUtils.saveChartAsJPEG(new File(picturePath), barChart, width, height)
  }

  override def paint(
      width: Int,
      height: Int,
      chartTitle: String,
      categoryAxisLabel: String,
      valueAxisLabel: String): Unit = {}
} 
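The painter converts raw byte counts to megabytes with Utils.byteStringAsMb, which parses size strings such as "128m" or "1g"; appending "b" treats the number as plain bytes. A tiny hedged sketch of the conversion helpers:

package org.apache.spark.demo

import org.apache.spark.util.Utils

object SizeConversions {
  // The Utils size parsers accept suffixes such as b, k, m, g (case-insensitive).
  def bytesToMb(bytes: Long): Long = Utils.byteStringAsMb(s"${bytes}b")
  def settingToBytes(sizeString: String): Long = Utils.byteStringAsBytes(sizeString)
}

// e.g. SizeConversions.bytesToMb(268435456L) == 256
//      SizeConversions.settingToBytes("512m") == 536870912L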
Example 119
Source File: MonitorFactory.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.monitor

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.alarm.Alarm
import org.apache.spark.util.Utils
import org.apache.spark.util.kvstore.KVStore

object MonitorFactory {

  def create(
      monitorName: String,
      alarms: Seq[Alarm],
      appStore: KVStore,
      conf: SparkConf): Monitor = {
    val loader = Utils.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[Monitor], loader)
    val MonitorClass = serviceLoader.asScala
      .filter(_.item.equals(MonitorItem.withName(monitorName)))
      .toList match {
      case head :: Nil =>
        head.getClass
      case _ =>
        throw new SparkException("error when instantiate spark.xsql.monitor.items")
    }
    MonitorClass.newInstance().bind(alarms).bind(appStore).bind(conf)
  }
} 
Example 120
Source File: SQLContextExtensionBase.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.extension

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{ParserDialect, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, SimpleFunctionRegistry}
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.DDLParser
import org.apache.spark.sql.extension.OptimizerFactory.ExtendableOptimizerBatch
import org.apache.spark.util.Utils

import scala.util.Try
import scala.util.control.NonFatal


  override protected def extendedParserDialect: ParserDialect =
    try {
      val clazz = Utils.classForName(dialectClassName)
      clazz.newInstance().asInstanceOf[ParserDialect]
    } catch {
      case NonFatal(e) =>
        // Since we didn't find the available SQL Dialect, it will fail even for SET command:
        // SET spark.sql.dialect=sql; Let's reset as default dialect automatically.
        val dialect = conf.dialect
        // reset the sql dialect
        conf.unsetConf(SQLConf.DIALECT)
        // throw out the exception, and the default sql dialect will take effect for next query.
        throw new DialectException(
          s"""
              |Instantiating dialect '$dialect' failed.
              |Reverting to default dialect '${conf.dialect}'""".stripMargin, e)
    }

  // (suggestion) make this implicit to FunctionRegistry.
  protected def registerBuiltins(registry: FunctionRegistry): Unit = {
    FunctionRegistry.expressions.foreach {
      case (name, (info, builder)) => registry.registerFunction(name, builder)
    }
  }

  override protected def extendedDdlParser(parser: String => LogicalPlan): DDLParser =
    new DDLParser(sqlParser.parse(_))

  override protected def registerFunctions(registry: FunctionRegistry): Unit = { }

} 
Example 121
Source File: SapSQLEnv.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.sap.thriftserver

import java.io.PrintStream

import org.apache.spark.scheduler.StatsReportListener
import org.apache.spark.sql.hive.{HiveContext, SapHiveContext}
import org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver
import org.apache.spark.sql.hive.thriftserver.SparkSQLEnv._
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkConf, SparkContext}

import scala.collection.JavaConversions._


object SapSQLEnv extends Logging {

  def init() {
    logDebug("Initializing SapSQLEnv")
    if (hiveContext == null) {
      logInfo("Creating SapSQLContext")
      val sparkConf = new SparkConf(loadDefaults = true)
      val maybeSerializer = sparkConf.getOption("spark.serializer")
      val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking")
      // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of
      // the default appName [SparkSQLCLIDriver] in cli or beeline.
      val maybeAppName = sparkConf
        .getOption("spark.app.name")
        .filterNot(_ == classOf[SparkSQLCLIDriver].getName)

      sparkConf
        .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}"))
        .set("spark.serializer",
          maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer"))
        .set("spark.kryo.referenceTracking",
          maybeKryoReferenceTracking.getOrElse("false"))

      sparkContext = new SparkContext(sparkConf)
      sparkContext.addSparkListener(new StatsReportListener())
      hiveContext = new SapHiveContext(sparkContext)

      hiveContext.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8"))
      hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8"))
      hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8"))

      hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion)

      if (log.isDebugEnabled) {
        hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) =>
          logDebug(s"HiveConf var: $k=$v")
        }
      }
    }
  }
} 
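SapSQLEnv falls back to s"SparkSQL::${Utils.localHostName()}" when no app name is configured. A minimal sketch of that defaulting, with a hypothetical prefix:

package org.apache.spark.demo

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

object AppNameDefaults {
  // Use the configured name if present, otherwise tag the app with the host name.
  def resolveAppName(conf: SparkConf, prefix: String = "MyApp"): String =
    conf.getOption("spark.app.name").getOrElse(s"$prefix::${Utils.localHostName()}")
}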
Example 122
Source File: MCLModelSuite.scala    From MCL_spark   with MIT License 5 votes vote down vote up
package org.apache.spark.mllib.clustering

import org.apache.log4j.{Level, Logger}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.util.Utils


class MCLModelSuite extends MCLFunSuite{
  // Disable Spark messages when running program
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  test("model save/load", UnitTest){

    val users: RDD[(VertexId, String)] =
      sc.parallelize(Array((0L,"Node1"), (1L,"Node2"),
        (2L,"Node3"), (3L,"Node4"),(4L,"Node5"),
        (5L,"Node6"), (6L,"Node7"), (7L, "Node8"),
        (8L, "Node9"), (9L, "Node10"), (10L, "Node11")))

    val relationships: RDD[Edge[Double]] =
      sc.parallelize(
        Seq(Edge(0, 1, 1.0), Edge(1, 0, 1.0),
          Edge(0, 2, 1.0), Edge(2, 0, 1.0),
          Edge(0, 3, 1.0), Edge(3, 0, 1.0),
          Edge(1, 2, 1.0), Edge(2, 1, 1.0),
          Edge(1, 3, 1.0), Edge(3, 1, 1.0),
          Edge(2, 3, 1.0), Edge(3, 2, 1.0),
          Edge(4, 5, 1.0), Edge(5, 4, 1.0),
          Edge(4, 6, 1.0), Edge(6, 4, 1.0),
          Edge(4, 7, 1.0), Edge(7, 4, 1.0),
          Edge(5, 6, 1.0), Edge(6, 5, 1.0),
          Edge(5, 7, 1.0), Edge(7, 5, 1.0),
          Edge(6, 7, 1.0), Edge(7, 6, 1.0),
          Edge(3, 8, 1.0), Edge(8, 3, 1.0),
          Edge(9, 8, 1.0), Edge(8, 9, 1.0),
          Edge(9, 10, 1.0), Edge(10, 9, 1.0),
          Edge(4, 10, 1.0), Edge(10, 4, 1.0)
        ))

    val graph = Graph(users, relationships)

    val model: MCLModel = MCL.train(graph)

    // Check number of clusters
    model.nbClusters shouldEqual 3

    // Check save and load methods
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString

    Array(true, false).foreach { case selector =>
      // Save model, load it back, and compare.
      try {
        model.save(sc, path)
        val sameModel = MCLModel.load(sc, path)
        assertDatasetEquals(model.assignments.orderBy("id"), sameModel.assignments.orderBy("id"))
      } finally {
        Utils.deleteRecursively(tempDir)
      }
    }

  }

  test("nodes assignments", UnitTest) {
    val nodeId = 1.0.toLong
    val cluster = 2.0.toLong
    val newAssignment:Assignment = Assignment.apply(Row(nodeId, cluster))

    newAssignment.id shouldEqual nodeId
    newAssignment.cluster shouldEqual cluster
  }

} 
Example 123
Source File: EventHubsWriter.scala    From azure-event-hubs-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.eventhubs

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{ AnalysisException, SparkSession }
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.{ BinaryType, StringType }
import org.apache.spark.util.Utils


private[eventhubs] object EventHubsWriter extends Logging {

  val BodyAttributeName = "body"
  val PartitionKeyAttributeName = "partitionKey"
  val PartitionIdAttributeName = "partition"
  val PropertiesAttributeName = "properties"

  override def toString: String = "EventHubsWriter"

  private def validateQuery(schema: Seq[Attribute], parameters: Map[String, String]): Unit = {
    schema
      .find(_.name == BodyAttributeName)
      .getOrElse(
        throw new AnalysisException(s"Required attribute '$BodyAttributeName' not found.")
      )
      .dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(
          s"$BodyAttributeName attribute type " +
            s"must be a String or BinaryType.")
    }
  }

  def write(
      sparkSession: SparkSession,
      queryExecution: QueryExecution,
      parameters: Map[String, String]
  ): Unit = {
    val schema = queryExecution.analyzed.output
    validateQuery(schema, parameters)
    queryExecution.toRdd.foreachPartition { iter =>
      val writeTask = new EventHubsWriteTask(parameters, schema)
      Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
        finallyBlock = writeTask.close()
      )
    }
  }
} 
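EventHubsWriter runs each partition's write task under Utils.tryWithSafeFinally, which always runs the finally block and, if both blocks throw, attaches the finally-side failure as a suppressed exception rather than masking the original. A compact sketch around a hypothetical closeable resource:

package org.apache.spark.demo

import java.io.Closeable

import org.apache.spark.util.Utils

object SafeClose {
  // Run `work` against a resource and always close it; the original exception
  // from `work` wins if close() also fails.
  def withResource[R <: Closeable, T](resource: R)(work: R => T): T =
    Utils.tryWithSafeFinally {
      work(resource)
    } {
      resource.close()
    }
}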
Example 124
Source File: OapEnv.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.io.PrintStream

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
import org.apache.spark.sql.oap.listener.OapListener
import org.apache.spark.sql.oap.ui.OapTab
import org.apache.spark.util.Utils


private[spark] object OapEnv extends Logging {
  logDebug("Initializing Oap Env")

  var initialized: Boolean = false
  var sparkSession: SparkSession = _

  // This is to enable certain OAP features, like UI, even
  // in non-Spark SQL CLI/ThriftServer conditions
  def initWithoutCreatingSparkSession(): Unit = synchronized {
    if (!initialized && !Utils.isTesting) {
      val sc = SparkContext.getOrCreate()
      sc.addSparkListener(new OapListener)
      this.sparkSession = SparkSession.getActiveSession.get
      sc.ui.foreach(new OapTab(_))
      initialized = true
    }
  }
} 
Example 125
Source File: OapRpcManagerSlave.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.oap.rpc

import java.util.concurrent.TimeUnit

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.execution.datasources.oap.filecache.{CacheStats, FiberCacheManager}
import org.apache.spark.sql.internal.oap.OapConf
import org.apache.spark.sql.oap.adapter.RpcEndpointRefAdapter
import org.apache.spark.sql.oap.rpc.OapMessages._
import org.apache.spark.storage.BlockManager
import org.apache.spark.util.{ThreadUtils, Utils}


private[spark] class OapRpcManagerSlave(
    rpcEnv: RpcEnv,
    val driverEndpoint: RpcEndpointRef,
    executorId: String,
    blockManager: BlockManager,
    fiberCacheManager: FiberCacheManager,
    conf: SparkConf) extends OapRpcManager {

  // Send OapHeartbeatMessage to Driver timed
  private val oapHeartbeater =
    ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater")

  private val slaveEndpoint = rpcEnv.setupEndpoint(
    s"OapRpcManagerSlave_$executorId", new OapRpcManagerSlaveEndpoint(rpcEnv, fiberCacheManager))

  initialize()
  startOapHeartbeater()

  protected def heartbeatMessages: Array[() => Heartbeat] = {
    Array(
      () => FiberCacheHeartbeat(
        executorId, blockManager.blockManagerId, fiberCacheManager.status()),
      () => FiberCacheMetricsHeartbeat(executorId, blockManager.blockManagerId,
        CacheStats.status(fiberCacheManager.cacheStats, conf)))
  }

  private def initialize() = {
    RpcEndpointRefAdapter.askSync[Boolean](
      driverEndpoint, RegisterOapRpcManager(executorId, slaveEndpoint))
  }

  override private[spark] def send(message: OapMessage): Unit = {
    driverEndpoint.send(message)
  }

  private[sql] def startOapHeartbeater(): Unit = {

    def reportHeartbeat(): Unit = {
      // OapRpcManagerSlave is created in SparkEnv. Before we start the heartbeat, we need to
      // make sure the SparkEnv has been created and the block manager has been initialized.
      // We check blockManagerId because it is only set after initialization.
      if (blockManager.blockManagerId != null) {
        heartbeatMessages.map(_.apply()).foreach(send)
      }
    }

    val intervalMs = conf.getTimeAsMs(
      OapConf.OAP_HEARTBEAT_INTERVAL.key, OapConf.OAP_HEARTBEAT_INTERVAL.defaultValue.get)

    // Wait a random interval so the heartbeats don't end up in sync
    val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int]

    val heartbeatTask = new Runnable() {
      override def run(): Unit = Utils.logUncaughtExceptions(reportHeartbeat())
    }
    oapHeartbeater.scheduleAtFixedRate(
      heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
  }

  override private[spark] def stop(): Unit = {
    oapHeartbeater.shutdown()
  }
}

private[spark] class OapRpcManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv, fiberCacheManager: FiberCacheManager)
  extends ThreadSafeRpcEndpoint with Logging {

  override def receive: PartialFunction[Any, Unit] = {
    case message: OapMessage => handleOapMessage(message)
    case _ =>
  }

  private def handleOapMessage(message: OapMessage): Unit = message match {
    case CacheDrop(indexName) => fiberCacheManager.releaseIndexCache(indexName)
    case _ =>
  }
} 
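The heartbeat task above wraps its body in Utils.logUncaughtExceptions so a failure inside the scheduled runnable is logged (and then rethrown) rather than silently killing the task. A minimal standalone sketch of the same pattern follows; it is illustrative only, not from OAP, and assumes a file under the org.apache.spark package because both Utils and ThreadUtils are private[spark].

package org.apache.spark.examples.util

import java.util.concurrent.TimeUnit

import org.apache.spark.util.{ThreadUtils, Utils}

object LogUncaughtExceptionsSketch {
  def main(args: Array[String]): Unit = {
    val pool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("sketch-heartbeater")
    val task = new Runnable {
      // Any exception thrown here is logged with the thread name and then rethrown,
      // so a dying scheduled task still leaves a trace in the logs.
      override def run(): Unit = Utils.logUncaughtExceptions {
        println(s"tick at ${System.currentTimeMillis()}")
      }
    }
    pool.scheduleAtFixedRate(task, 0L, 1000L, TimeUnit.MILLISECONDS)
    Thread.sleep(3000L)
    pool.shutdown()
  }
}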
Example 126
Source File: OapBitmapWrappedFiberCacheSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.utils

import java.io.{ByteArrayOutputStream, DataOutputStream, FileOutputStream}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.roaringbitmap.RoaringBitmap

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.datasources.OapException
import org.apache.spark.sql.execution.datasources.oap.filecache.{BitmapFiberId, FiberCache}
import org.apache.spark.sql.oap.OapRuntime
import org.apache.spark.sql.test.oap.SharedOapContext
import org.apache.spark.util.Utils

// The tests below exercise the functionality of the OapBitmapWrappedFiberCache class.
class OapBitmapWrappedFiberCacheSuite
  extends QueryTest with SharedOapContext {

  private def loadRbFile(fin: FSDataInputStream, offset: Long, size: Int): FiberCache =
    OapRuntime.getOrCreate.fiberCacheManager.toIndexFiberCache(fin, offset, size)

  test("test the functionality of OapBitmapWrappedFiberCache class") {
    val CHUNK_SIZE = 1 << 16
    val dataForRunChunk = (1 to 9).toSeq
    val dataForArrayChunk = Seq(1, 3, 5, 7, 9)
    val dataForBitmapChunk = (1 to 10000).filter(_ % 2 == 1)
    val dataCombination =
      dataForBitmapChunk ++ dataForArrayChunk ++ dataForRunChunk
    val dataArray =
      Array(dataForRunChunk, dataForArrayChunk, dataForBitmapChunk, dataCombination)
    dataArray.foreach(dataIdx => {
      val dir = Utils.createTempDir()
      val rb = new RoaringBitmap()
      dataIdx.foreach(rb.add)
      val rbFile = dir.getAbsolutePath + "rb.bin"
      rb.runOptimize()
      val rbFos = new FileOutputStream(rbFile)
      val rbBos = new ByteArrayOutputStream()
      val rbDos = new DataOutputStream(rbBos)
      rb.serialize(rbDos)
      rbBos.writeTo(rbFos)
      rbBos.close()
      rbDos.close()
      rbFos.close()
      val rbPath = new Path(rbFile.toString)
      val conf = new Configuration()
      val fin = rbPath.getFileSystem(conf).open(rbPath)
      val rbFileSize = rbPath.getFileSystem(conf).getFileStatus(rbPath).getLen
      val rbFiber = BitmapFiberId(
        () => loadRbFile(fin, 0L, rbFileSize.toInt), rbPath.toString, 0, 0)
      val rbWfc = new OapBitmapWrappedFiberCache(
        OapRuntime.getOrCreate.fiberCacheManager.get(rbFiber))
      rbWfc.init
      val chunkLength = rbWfc.getTotalChunkLength
      val length = dataIdx.size / CHUNK_SIZE
      assert(chunkLength == (length + 1))
      val chunkKeys = rbWfc.getChunkKeys
      assert(chunkKeys(0).toInt == 0)
      rbWfc.setOffset(0)
      val chunk = rbWfc.getIteratorForChunk(0)
      chunk match {
        case RunChunkIterator(rbWfc) => assert(chunk == RunChunkIterator(rbWfc))
        case ArrayChunkIterator(rbWfc, 0) => assert(chunk == ArrayChunkIterator(rbWfc, 0))
        case BitmapChunkIterator(rbWfc) => assert(chunk == BitmapChunkIterator(rbWfc))
        case _ => throw new OapException("unexpected chunk in OapBitmapWrappedFiberCache.")
      }
      rbWfc.release
      fin.close
      dir.delete
    })
  }
} 
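Like most suites in this list, the test above gets its scratch space from Utils.createTempDir() and cleans it up afterwards. The sketch below shows that pair of helpers on their own; it is illustrative only and assumes a file under the org.apache.spark package, since Utils is private[spark].

package org.apache.spark.examples.util

import java.io.File
import java.nio.file.Files

import org.apache.spark.util.Utils

object TempDirSketch {
  def main(args: Array[String]): Unit = {
    // createTempDir registers the directory for deletion on JVM shutdown,
    // but tests usually clean it up eagerly with deleteRecursively.
    val dir: File = Utils.createTempDir()
    val file = new File(dir, "rb.bin")
    Files.write(file.toPath, Array[Byte](1, 2, 3))
    println(s"wrote ${file.length()} bytes under ${dir.getAbsolutePath}")
    Utils.deleteRecursively(dir)
  }
}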
Example 127
Source File: BitmapAnalyzeStatisticsSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.index

import org.apache.hadoop.fs.RawLocalFileSystem
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.oap.SharedOapContext
import org.apache.spark.util.Utils


trait SharedOapContextWithRawLocalFileSystem extends SharedOapContext {
  oapSparkConf.set("spark.hadoop.fs.file.impl", classOf[RawLocalFileSystem].getName)
}

class BitmapAnalyzeStatisticsSuite extends QueryTest with SharedOapContextWithRawLocalFileSystem
    with BeforeAndAfterEach {
  import testImplicits._

  override def beforeEach(): Unit = {
    val tempDir = Utils.createTempDir()
    val path = tempDir.getAbsolutePath
    sql(s"""CREATE TEMPORARY VIEW oap_test (a INT, b STRING)
            | USING oap
            | OPTIONS (path '$path')""".stripMargin)
  }

  override def afterEach(): Unit = {
    sqlContext.dropTempTable("oap_test")
  }

  test("Bitmap index typical equal test") {
    val data: Seq[(Int, String)] = (1 to 200).map { i => (i, s"this is test $i") }
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test select * from t")
    sql("create oindex idxa on oap_test (a) USING BITMAP")
    checkAnswer(sql(s"SELECT * FROM oap_test WHERE a = 20 OR a = 21"),
      Row(20, "this is test 20") :: Row(21, "this is test 21") :: Nil)
    sql("drop oindex idxa on oap_test")
  }
} 
Example 128
Source File: OapIndexCommitProtocolSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.index

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.MRJobConfig
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.task.JobContextImpl
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import org.apache.spark.sql.test.oap.SharedOapContext
import org.apache.spark.util.Utils

class OapIndexCommitProtocolSuite extends SharedOapContext {
  test("newTaskTempFile") {
    val attempt = "attempt_200707121733_0001_m_000000_0"
    val taskID = TaskAttemptID.forName(attempt)
    val jobID = taskID.getJobID.toString
    val outDir = Utils.createTempDir().getAbsolutePath
    val job = Job.getInstance()
    FileOutputFormat.setOutputPath(job, new Path(outDir))
    val conf = job.getConfiguration()
    conf.set(MRJobConfig.TASK_ATTEMPT_ID, attempt)
    val jobContext = new JobContextImpl(conf, taskID.getJobID())
    val taskContext = new TaskAttemptContextImpl(conf, taskID)
    val commitProtocol = new OapIndexCommitProtocol(jobID, outDir)
    // test task temp path
    val pendingDirName = "_temporary_" + jobID
    commitProtocol.setupJob(jobContext)
    commitProtocol.setupTask(taskContext)
    val tempFile = new Path(commitProtocol.newTaskTempFile(taskContext, None, "test"))
    val expectedJobAttemptPath = new Path(new Path(outDir, pendingDirName), "0")
    val expectedTaskWorkPath = new Path(new Path(expectedJobAttemptPath, pendingDirName), attempt)
    assert(tempFile.getParent == expectedTaskWorkPath)
  }
} 
Example 129
Source File: ClusteredFilterSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.cluster

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.oap.SharedOapLocalClusterContext
import org.apache.spark.util.Utils

class ClusteredFilterSuite
  extends QueryTest with SharedOapLocalClusterContext with BeforeAndAfterEach {

  import testImplicits._

  private var currentPath: String = _

  override def beforeEach(): Unit = {
    val path = Utils.createTempDir().getAbsolutePath
    currentPath = path
    sql(s"""CREATE TEMPORARY VIEW oap_test (a INT, b STRING)
           | USING oap
           | OPTIONS (path '$path')""".stripMargin)
  }

  override def afterEach(): Unit = {
    sqlContext.dropTempTable("oap_test")
  }

  test("filtering") {
    val data: Seq[(Int, String)] = (1 to 300).map { i => (i, s"this is test $i") }
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test select * from t")
    sql("create oindex index1 on oap_test (a)")

    checkAnswer(sql("SELECT * FROM oap_test WHERE a = 1"),
      Row(1, "this is test 1") :: Nil)

    checkAnswer(sql("SELECT * FROM oap_test WHERE a > 1 AND a <= 3"),
      Row(2, "this is test 2") :: Row(3, "this is test 3") :: Nil)

    checkAnswer(sql("SELECT * FROM oap_test WHERE a <= 2"),
      Row(1, "this is test 1") :: Row(2, "this is test 2") :: Nil)

    checkAnswer(sql("SELECT * FROM oap_test WHERE a >= 300"),
      Row(300, "this is test 300") :: Nil)

    sql("drop oindex index1 on oap_test")
  }

} 
Example 130
Source File: DataFileSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.io

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.datasources.OapException
import org.apache.spark.sql.execution.datasources.oap.OapFileFormat
import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
import org.apache.spark.sql.test.oap.SharedOapContext
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

class DataFileSuite extends QueryTest with SharedOapContext {

  override def beforeEach(): Unit = {
    val path = Utils.createTempDir().getAbsolutePath
  }

  // Override afterEach because OapDataFile opens an InputStream for OapDataFileMeta
  // but provides no way to close it manually, so we cannot check for open streams.
  override def afterEach(): Unit = {}

  test("apply and cache") {
    val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString))
    val schema = new StructType()
    val config = new Configuration()

    withTempPath { dir =>
      val df = spark.createDataFrame(data)
      df.repartition(1).write.format("oap").save(dir.getAbsolutePath)
      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
      val datafile =
        DataFile(file, schema, OapFileFormat.OAP_DATA_FILE_V1_CLASSNAME, config)
      assert(datafile.path == file)
      assert(datafile.schema == schema)
      assert(datafile.configuration == config)
    }

    withTempPath { dir =>
      val df = spark.createDataFrame(data)
      df.repartition(1).write.parquet(dir.getAbsolutePath)
      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
      val datafile =
        DataFile(file, schema, OapFileFormat.PARQUET_DATA_FILE_CLASSNAME, config)
      assert(datafile.path == file)
      assert(datafile.schema == schema)
      assert(datafile.configuration == config)
    }

    withTempPath { dir =>
      val df = spark.createDataFrame(data)
      df.repartition(1).write.format("orc").save(dir.getAbsolutePath)
      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
      val datafile =
        DataFile(file, schema, OapFileFormat.ORC_DATA_FILE_CLASSNAME, config)
      assert(datafile.path == file)
      assert(datafile.schema == schema)
      assert(datafile.configuration == config)
    }

    // The DataFile object is global. Now that OrcDataFile has been added, this needs
    // to be 3 when the whole test suite is run.
    assert(DataFile.cachedConstructorCount == 3)

    intercept[OapException] {
      DataFile("nofile", schema, "NotExistClass", config)
      assert(DataFile.cachedConstructorCount == 2)
    }
  }

  test("DataFile equals") {
    val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString))
    val schema = new StructType()
    val config = new Configuration()
    withTempPath { dir =>
      val df = spark.createDataFrame(data)
      df.repartition(1).write.parquet(dir.getAbsolutePath)
      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
      val datafile1 =
        DataFile(file, schema, OapFileFormat.PARQUET_DATA_FILE_CLASSNAME, config)
      val datafile2 =
        DataFile(file, schema, OapFileFormat.PARQUET_DATA_FILE_CLASSNAME, config)
      assert(datafile1.equals(datafile2))
      assert(datafile1.hashCode() == datafile2.hashCode())
    }

    withTempPath { dir =>
      val df = spark.createDataFrame(data)
      df.repartition(1).write.format("oap").save(dir.getAbsolutePath)
      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
      val datafile1 =
        DataFile(file, schema, OapFileFormat.OAP_DATA_FILE_V1_CLASSNAME, config)
      val datafile2 =
        DataFile(file, schema, OapFileFormat.OAP_DATA_FILE_V1_CLASSNAME, config)
      assert(datafile1.equals(datafile2))
      assert(datafile1.hashCode() == datafile2.hashCode())
    }
  }
} 
Example 131
Source File: FileSkipSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.io

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.oap.SharedOapContext
import org.apache.spark.util.Utils

class FileSkipSuite extends QueryTest with SharedOapContext with BeforeAndAfterEach {
  import testImplicits._

  override def beforeEach(): Unit = {
    val path1 = Utils.createTempDir().getAbsolutePath

    sql(s"""CREATE TEMPORARY VIEW oap_test_1 (a INT, b STRING)
           | USING oap
           | OPTIONS (path '$path1')""".stripMargin)
  }

  override def afterEach(): Unit = {
    sqlContext.dropTempTable("oap_test_1")
  }

  test("skip all file (is not null)") {
    val data: Seq[(Int, String)] =
      scala.util.Random.shuffle(1 to 300).map(i => (i, null)).toSeq
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test_1 select * from t")
    val result = sql("SELECT * FROM oap_test_1 WHERE b is not null")
    assert(result.count == 0)
  }

  test("skip all file (equal)") {
    val data: Seq[(Int, String)] =
      scala.util.Random.shuffle(1 to 300).map(i => (i, s"this is test $i")).toSeq
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test_1 select * from t")
    val result1 = sql("SELECT * FROM oap_test_1 WHERE a = 1")
    assert(result1.count == 1)
    val result2 = sql("SELECT * FROM oap_test_1 WHERE a = 500")
    assert(result2.count == 0)
  }

  test("skip all file (lt)") {
    val data: Seq[(Int, String)] =
      scala.util.Random.shuffle(1 to 300).map(i => (i, s"this is test $i")).toSeq
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test_1 select * from t")
    val result1 = sql("SELECT * FROM oap_test_1 WHERE a < 1")
    assert(result1.count == 0)
    val result2 = sql("SELECT * FROM oap_test_1 WHERE a < 2")
    assert(result2.count == 1)
  }

  test("skip all file (lteq)") {
    val data: Seq[(Int, String)] =
      scala.util.Random.shuffle(1 to 300).map(i => (i, s"this is test $i")).toSeq
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test_1 select * from t")
    val result1 = sql("SELECT * FROM oap_test_1 WHERE a <= 0")
    assert(result1.count == 0)
    val result2 = sql("SELECT * FROM oap_test_1 WHERE a <= 1")
    assert(result2.count == 1)
  }

  test("skip all file (gt)") {
    val data: Seq[(Int, String)] =
      scala.util.Random.shuffle(1 to 300).map(i => (i, s"this is test $i")).toSeq
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test_1 select * from t")
    val result1 = sql("SELECT * FROM oap_test_1 WHERE a > 300")
    assert(result1.count == 0)
    val result2 = sql("SELECT * FROM oap_test_1 WHERE a > 2")
    assert(result2.count == 298)
  }

  test("skip all file (gteq)") {
    val data: Seq[(Int, String)] =
      scala.util.Random.shuffle(1 to 300).map(i => (i, s"this is test $i")).toSeq
    data.toDF("key", "value").createOrReplaceTempView("t")
    sql("insert overwrite table oap_test_1 select * from t")
    val result1 = sql("SELECT * FROM oap_test_1 WHERE a >= 300")
    assert(result1.count == 1)
    val result2 = sql("SELECT * FROM oap_test_1 WHERE a >= 500")
    assert(result2.count == 0)
  }
} 
Example 132
Source File: YarnClusterSchedulerBackend.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.conf.YarnConfiguration

import org.apache.spark.SparkContext
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils

private[spark] class YarnClusterSchedulerBackend(
    scheduler: TaskSchedulerImpl,
    sc: SparkContext)
  extends YarnSchedulerBackend(scheduler, sc) {

  override def start() {
    super.start()
    totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf)
  }

  override def applicationId(): String =
    // In YARN Cluster mode, the application ID is expected to be set, so log an error if it's
    // not found.
    sc.getConf.getOption("spark.yarn.app.id").getOrElse {
      logError("Application ID is not set.")
      super.applicationId
    }

  override def applicationAttemptId(): Option[String] =
    // In YARN Cluster mode, the attempt ID is expected to be set, so log an error if it's
    // not found.
    sc.getConf.getOption("spark.yarn.app.attemptId").orElse {
      logError("Application attempt ID is not set.")
      super.applicationAttemptId
    }

  override def getDriverLogUrls: Option[Map[String, String]] = {
    var driverLogs: Option[Map[String, String]] = None
    try {
      val yarnConf = new YarnConfiguration(sc.hadoopConfiguration)
      val containerId = YarnSparkHadoopUtil.get.getContainerId

      val httpAddress = System.getenv(Environment.NM_HOST.name()) +
        ":" + System.getenv(Environment.NM_HTTP_PORT.name())
      // lookup appropriate http scheme for container log urls
      val yarnHttpPolicy = yarnConf.get(
        YarnConfiguration.YARN_HTTP_POLICY_KEY,
        YarnConfiguration.YARN_HTTP_POLICY_DEFAULT
      )
      val user = Utils.getCurrentUserName()
      val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
      val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user"
      logDebug(s"Base URL for logs: $baseUrl")
      driverLogs = Some(Map(
        "stderr" -> s"$baseUrl/stderr?start=-4096",
        "stdout" -> s"$baseUrl/stdout?start=-4096"))
    } catch {
      case e: Exception =>
        logInfo("Error while building AM log links, so AM" +
          " logs link will not appear in application UI", e)
    }
    driverLogs
  }
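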
} 
Example 133
Source File: YarnScheduler.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.apache.hadoop.yarn.util.RackResolver

import org.apache.log4j.{Level, Logger}

import org.apache.spark._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils

private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {

  // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
  if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
    Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
  }

  // By default, rack is unknown
  override def getRackForHost(hostPort: String): Option[String] = {
    val host = Utils.parseHostPort(hostPort)._1
    Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation)
  }
} 
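getRackForHost above only needs the host part of a "host:port" string, which is what Utils.parseHostPort extracts. A small illustrative sketch, not from the project above, assuming a file under the org.apache.spark package:

package org.apache.spark.examples.util

import org.apache.spark.util.Utils

object ParseHostPortSketch {
  def main(args: Array[String]): Unit = {
    // Splits "host:port" into its parts; a missing port comes back as 0.
    val (host, port) = Utils.parseHostPort("rack-1-node-7:45123")
    println(s"host=$host port=$port")              // host=rack-1-node-7 port=45123
    println(Utils.parseHostPort("rack-1-node-7"))  // (rack-1-node-7,0)
  }
}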
Example 134
Source File: RateController.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


  private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
    Future[Unit] {
      val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay)
      newRate.foreach { s =>
        rateLimit.set(s.toLong)
        publish(getLatestRate())
      }
    }

  def getLatestRate(): Long = rateLimit.get()

  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
    val elements = batchCompleted.batchInfo.streamIdToInputInfo

    for {
      processingEnd <- batchCompleted.batchInfo.processingEndTime
      workDelay <- batchCompleted.batchInfo.processingDelay
      waitDelay <- batchCompleted.batchInfo.schedulingDelay
      elems <- elements.get(streamUID).map(_.numRecords)
    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
  }
}

object RateController {
  def isBackPressureEnabled(conf: SparkConf): Boolean =
    conf.getBoolean("spark.streaming.backpressure.enabled", false)
} 
Example 135
Source File: JobSet.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import scala.collection.mutable.HashSet
import scala.util.Failure

import org.apache.spark.streaming.Time
import org.apache.spark.util.Utils


private[streaming]
case class JobSet(
    time: Time,
    jobs: Seq[Job],
    streamIdToInputInfo: Map[Int, StreamInputInfo] = Map.empty) {

  private val incompleteJobs = new HashSet[Job]()
  private val submissionTime = System.currentTimeMillis() // when this jobset was submitted
  private var processingStartTime = -1L // when the first job of this jobset started processing
  private var processingEndTime = -1L // when the last job of this jobset finished processing

  jobs.zipWithIndex.foreach { case (job, i) => job.setOutputOpId(i) }
  incompleteJobs ++= jobs

  def handleJobStart(job: Job) {
    if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
  }

  def handleJobCompletion(job: Job) {
    incompleteJobs -= job
    if (hasCompleted) processingEndTime = System.currentTimeMillis()
  }

  def hasStarted: Boolean = processingStartTime > 0

  def hasCompleted: Boolean = incompleteJobs.isEmpty

  // Time taken to process all the jobs from the time they started processing
  // (i.e. not including the time they wait in the streaming scheduler queue)
  def processingDelay: Long = processingEndTime - processingStartTime

  // Time taken to process all the jobs from the time they were submitted
  // (i.e. including the time they wait in the streaming scheduler queue)
  def totalDelay: Long = {
    processingEndTime - time.milliseconds
  }

  def toBatchInfo: BatchInfo = {
    val failureReasons: Map[Int, String] = {
      if (hasCompleted) {
        jobs.filter(_.result.isFailure).map { job =>
          (job.outputOpId, Utils.exceptionString(job.result.asInstanceOf[Failure[_]].exception))
        }.toMap
      } else {
        Map.empty
      }
    }
    val binfo = new BatchInfo(
      time,
      streamIdToInputInfo,
      submissionTime,
      if (processingStartTime >= 0) Some(processingStartTime) else None,
      if (processingEndTime >= 0) Some(processingEndTime) else None
    )
    binfo.setFailureReason(failureReasons)
    binfo.setNumOutputOp(jobs.size)
    binfo
  }
} 
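toBatchInfo above turns each failed job's exception into a plain string with Utils.exceptionString, which renders the message plus the full stack trace. A hedged illustrative sketch, assuming a file under the org.apache.spark package:

package org.apache.spark.examples.util

import org.apache.spark.util.Utils

object ExceptionStringSketch {
  def main(args: Array[String]): Unit = {
    val failure =
      try { "42x".toInt; "" } catch {
        // exceptionString is what JobSet stores as the per-output-op failure reason.
        case e: NumberFormatException => Utils.exceptionString(e)
      }
    println(failure)
  }
}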
Example 136
Source File: FailureSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkFunSuite, Logging}
import org.apache.spark.util.Utils


class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  private val batchDuration: Duration = Milliseconds(1000)
  private val numBatches = 30
  private var directory: File = null

  before {
    directory = Utils.createTempDir()
  }

  after {
    if (directory != null) {
      // Delete the temporary directory
      Utils.deleteRecursively(directory)
    }
    // Stop all active StreamingContexts
    StreamingContext.getActive().foreach { _.stop() }
  }
  // Multiple failures with map
  test("multiple failures with map") {
    MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration)
  }
  // Multiple failures with updateStateByKey
  test("multiple failures with updateStateByKey") {
    MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration)
  }
} 
Example 137
Source File: BroadcastManager.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.broadcast

import java.util.concurrent.atomic.AtomicLong

import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.util.Utils

private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager)
  extends Logging {

  private var initialized = false
  private var broadcastFactory: BroadcastFactory = null

  initialize() // Initialize the broadcastFactory member based on the configuration

  // Called by SparkContext or Executor before using Broadcast
  private def initialize() {
    synchronized {
      if (!initialized) {
        // The broadcast implementation class
        val broadcastFactoryClass =
          conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
        broadcastFactory =
          Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
        // Initialize appropriate BroadcastFactory and BroadcastObject
        // Call the initialization function
        broadcastFactory.initialize(isDriver, conf, securityManager)
        initialized = true // Initialization complete
      }
    }
  }

  def stop() {
    broadcastFactory.stop()
  }
  // Counter used to generate broadcast IDs
  private val nextBroadcastId = new AtomicLong(0)

  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }

  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
  }
} 
Example 138
Source File: SerializableWritable.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.io._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.Writable

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

@DeveloperApi
class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {

  def value: T = t

  override def toString: String = t.toString

  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    out.defaultWriteObject()
    new ObjectWritable(t).write(out)
  }

  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    val ow = new ObjectWritable()
    ow.setConf(new Configuration(false))
    ow.readFields(in)
    t = ow.get().asInstanceOf[T]
  }
} 
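SerializableWritable wraps its custom readObject/writeObject in Utils.tryOrIOException so any non-fatal error surfaces as the IOException Java serialization expects. The sketch below applies the same pattern to a hypothetical SerializableBox class; it is illustrative only and assumes a file under the org.apache.spark package.

package org.apache.spark.examples.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.spark.util.Utils

// A tiny Serializable wrapper in the style of SerializableWritable: non-fatal errors
// raised while (de)serializing are re-thrown as IOException.
class SerializableBox(@transient var value: String) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    out.defaultWriteObject()
    out.writeUTF(value)
  }

  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    value = in.readUTF()
  }
}

object TryOrIOExceptionSketch {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(new SerializableBox("hello"))
    oos.close()
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    println(ois.readObject().asInstanceOf[SerializableBox].value) // hello
  }
}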
Example 139
Source File: Message.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.nio

import java.net.InetSocketAddress
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import com.google.common.base.Charsets.UTF_8

import org.apache.spark.util.Utils

private[nio] abstract class Message(val typ: Long, val id: Int) {
  var senderAddress: InetSocketAddress = null
  var started = false
  var startTime = -1L
  var finishTime = -1L
  var isSecurityNeg = false
  var hasError = false

  def size: Int

  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]

  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]

  def timeTaken(): String = (finishTime - startTime).toString + " ms"

  override def toString: String = {
    this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
  }
}


private[nio] object Message {
  val BUFFER_MESSAGE = 1111111111L

  var lastId = 1

  def getNewId(): Int = synchronized {
    lastId += 1
    if (lastId == 0) {
      lastId += 1
    }
    lastId
  }

  def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
    if (dataBuffers == null) {
      return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
    }
    if (dataBuffers.exists(_ == null)) {
      throw new Exception("Attempting to create buffer message with null buffer")
    }
    new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
  }

  def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
    createBufferMessage(dataBuffers, 0)

  def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
    if (dataBuffer == null) {
      // ByteBuffer.allocate: a buffer must exist before it can be read or written;
      // the static allocate() method allocates one
      createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
    } else {
      createBufferMessage(Array(dataBuffer), ackId)
    }
  }

  def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
    createBufferMessage(dataBuffer, 0)

  def createBufferMessage(ackId: Int): BufferMessage = {
    createBufferMessage(new Array[ByteBuffer](0), ackId)
  }

  
  def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
    val exceptionString = Utils.exceptionString(exception)
    val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8))
    val errorMessage = createBufferMessage(serializedExceptionString, ackId)
    errorMessage.hasError = true
    errorMessage
  }

  def create(header: MessageChunkHeader): Message = {
    val newMessage: Message = header.typ match {
      case BUFFER_MESSAGE => new BufferMessage(header.id,
        // Allocate the backing buffer with the static ByteBuffer.allocate() method
        ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
    }
    newMessage.hasError = header.hasError
    newMessage.senderAddress = header.address
    newMessage
  }
} 
Example 140
Source File: MetricsConfig.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.metrics

import java.io.{FileInputStream, InputStream}
import java.util.Properties

import scala.collection.mutable
import scala.util.matching.Regex

import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkConf}

private[spark] class MetricsConfig(conf: SparkConf) extends Logging {

  private val DEFAULT_PREFIX = "*"
  private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r
  private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties"

  private[metrics] val properties = new Properties()
  private[metrics] var propertyCategories: mutable.HashMap[String, Properties] = null

  private def setDefaultProperties(prop: Properties) {
    prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet")
    prop.setProperty("*.sink.servlet.path", "/metrics/json")
    prop.setProperty("master.sink.servlet.path", "/metrics/master/json")
    prop.setProperty("applications.sink.servlet.path", "/metrics/applications/json")
  }

  def initialize() {
    // Add default properties in case there's no properties file
    setDefaultProperties(properties)

    loadPropertiesFromFile(conf.getOption("spark.metrics.conf"))

    // Also look for the properties in provided Spark configuration
    val prefix = "spark.metrics.conf."
    conf.getAll.foreach {
      case (k, v) if k.startsWith(prefix) =>
        properties.setProperty(k.substring(prefix.length()), v)
      case _ =>
    }

    propertyCategories = subProperties(properties, INSTANCE_REGEX)
    if (propertyCategories.contains(DEFAULT_PREFIX)) {
      import scala.collection.JavaConversions._

      val defaultProperty = propertyCategories(DEFAULT_PREFIX)
      for { (inst, prop) <- propertyCategories
            if (inst != DEFAULT_PREFIX)
            (k, v) <- defaultProperty
            if (prop.getProperty(k) == null) } {
        prop.setProperty(k, v)
      }
    }
  }
  // Use the regex to match prefixed property keys and group them into a HashMap keyed by the matched prefix
  def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = {
    val subProperties = new mutable.HashMap[String, Properties]
    import scala.collection.JavaConversions._
    prop.foreach { kv =>
      if (regex.findPrefixOf(kv._1).isDefined) {
        val regex(prefix, suffix) = kv._1
        subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2)
      }
    }
    subProperties
  }

  def getInstance(inst: String): Properties = {
    propertyCategories.get(inst) match {
      case Some(s) => s
      case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties)
    }
  }

  
  private[this] def loadPropertiesFromFile(path: Option[String]): Unit = {
    var is: InputStream = null
    try {
      is = path match {
        case Some(f) => new FileInputStream(f)
        case None => Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)
      }

      if (is != null) {
        properties.load(is)
      }
    } catch {
      case e: Exception =>
        val file = path.getOrElse(DEFAULT_METRICS_CONF_FILENAME)
        logError(s"Error loading configuration file $file", e)
    } finally {
      if (is != null) {
        is.close()
      }
    }
  }

} 
Example 141
Source File: PythonGatewayServer.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    // Difference between System.getenv() and System.getProperties():
    // System.getenv() returns OS environment variables (which can be set, e.g., in the ".bashrc"
    // of the current user's home directory); System.getProperties() returns JVM properties,
    // typically set with the "-D" command-line option.
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
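PythonGatewayServer runs its whole main method inside Utils.tryOrExit, which hands any escaping exception to Spark's uncaught-exception handler so the process terminates instead of lingering. A minimal illustrative sketch (the environment variable name is made up), assuming a file under the org.apache.spark package:

package org.apache.spark.examples.util

import org.apache.spark.util.Utils

object TryOrExitSketch {
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // If anything non-fatal escapes this block, Spark's uncaught-exception handler
    // logs it and terminates the JVM instead of leaving a half-dead process.
    val port = sys.env.getOrElse("SKETCH_CALLBACK_PORT", "0").toInt
    println(s"would call back on port $port")
  }
}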
Example 142
Source File: PythonPartitioner.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import org.apache.spark.Partitioner
import org.apache.spark.util.Utils



private[spark] class PythonPartitioner(
  override val numPartitions: Int,
  val pyPartitionFunctionId: Long)
  extends Partitioner {

  override def getPartition(key: Any): Int = key match {
    case null => 0
    // we don't trust the Python partition function to return valid partition IDs, so
    // let's do a modulo numPartitions in any case
    case key: Long => Utils.nonNegativeMod(key.toInt, numPartitions)
    case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: PythonPartitioner =>
      h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
    case _ =>
      false
  }

  override def hashCode: Int = 31 * numPartitions + pyPartitionFunctionId.hashCode
} 
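The partitioner above uses Utils.nonNegativeMod because % in Scala/Java keeps the sign of the dividend, so a negative hash would otherwise map to a negative partition id. A short illustrative sketch, assuming a file under the org.apache.spark package:

package org.apache.spark.examples.util

import org.apache.spark.util.Utils

object NonNegativeModSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 8
    // The plain modulo keeps the sign; nonNegativeMod folds the result back into [0, mod).
    println(-3 % numPartitions)                       // -3
    println(Utils.nonNegativeMod(-3, numPartitions))  // 5
    println(Utils.nonNegativeMod("key".hashCode, numPartitions))
  }
}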
Example 143
Source File: MesosClusterDispatcherArguments.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.mesos

import org.apache.spark.SparkConf
import org.apache.spark.util.{IntParam, Utils}


private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) {
  var host = Utils.localHostName()
  var port = 7077
  var name = "Spark Cluster"
  var webUiPort = 8081
  var masterUrl: String = _
  var zookeeperUrl: Option[String] = None
  var propertiesFile: String = _

  parse(args.toList)

  propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)

  private def parse(args: List[String]): Unit = args match {
    case ("--host" | "-h") :: value :: tail =>
      Utils.checkHost(value, "Please use hostname " + value)
      host = value
      parse(tail)

    case ("--port" | "-p") :: IntParam(value) :: tail =>
      port = value
      parse(tail)

    case ("--webui-port" | "-p") :: IntParam(value) :: tail =>
      webUiPort = value
      parse(tail)

    case ("--zk" | "-z") :: value :: tail =>
      zookeeperUrl = Some(value)
      parse(tail)

    case ("--master" | "-m") :: value :: tail =>
      if (!value.startsWith("mesos://")) {
        // scalastyle:off println
        System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)")
        // scalastyle:on println
        System.exit(1)
      }
      masterUrl = value.stripPrefix("mesos://")
      parse(tail)

    case ("--name") :: value :: tail =>
      name = value
      parse(tail)

    case ("--properties-file") :: value :: tail =>
      propertiesFile = value
      parse(tail)

    case ("--help") :: tail =>
      printUsageAndExit(0)

    case Nil => {
      if (masterUrl == null) {
        // scalastyle:off println
        System.err.println("--master is required")
        // scalastyle:on println
        printUsageAndExit(1)
      }
    }

    case _ =>
      printUsageAndExit(1)
  }

  private def printUsageAndExit(exitCode: Int): Unit = {
    // scalastyle:off println
    System.err.println(
      "Usage: MesosClusterDispatcher [options]\n" +
        "\n" +
        "Options:\n" +
        "  -h HOST, --host HOST    Hostname to listen on\n" +
        "  -p PORT, --port PORT    Port to listen on (default: 7077)\n" +
        "  --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" +
        "  --name NAME             Framework name to show in Mesos UI\n" +
        "  -m --master MASTER      URI for connecting to Mesos master\n" +
        "  -z --zk ZOOKEEPER       Comma delimited URLs for connecting to \n" +
        "                          Zookeeper for persistence\n" +
        "  --properties-file FILE  Path to a custom Spark properties file.\n" +
        "                          Default is conf/spark-defaults.conf.")
    // scalastyle:on println
    System.exit(exitCode)
  }
} 
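The argument parser above fills in defaults with Utils.localHostName() and Utils.loadDefaultSparkProperties. The sketch below shows those two calls in isolation; it is illustrative only (whether a properties file is found depends on SPARK_CONF_DIR / SPARK_HOME) and assumes a file under the org.apache.spark package.

package org.apache.spark.examples.util

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

object DefaultPropertiesSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // With no explicit path, Spark falls back to conf/spark-defaults.conf under
    // SPARK_CONF_DIR or SPARK_HOME; the resolved path (possibly null) is returned.
    val propertiesFile = Utils.loadDefaultSparkProperties(conf, null)
    println(s"loaded defaults from: $propertiesFile")
    println(s"would bind to host: ${Utils.localHostName()}")
  }
}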
Example 144
Source File: TestClient.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.client

import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.util.Utils

private[spark] object TestClient {

  private class TestListener extends AppClientListener with Logging {
    def connected(id: String) {
      logInfo("Connected to master, got app ID " + id)
    }

    def disconnected() {
      logInfo("Disconnected from master")
      System.exit(0)
    }

    def dead(reason: String) {
      logInfo("Application died with error: " + reason)
      System.exit(0)
    }

    def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}

    def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
  }

  def main(args: Array[String]) {
    val url = if(args.isEmpty) "127.0.0.1" else args(0)
    
    val conf = new SparkConf
    val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf))
    val executorClassnamea = TestExecutor.getClass.getCanonicalName
    println("====executorClassname======"+executorClassnamea)
    // stripSuffix returns this string with the given suffix removed; if the string does not
    // end with the suffix, it is returned unchanged
    val executorClassname = TestExecutor.getClass.getCanonicalName.stripSuffix("$")
    println("====executorClassname======"+executorClassname)
    val desc = new ApplicationDescription("TestClient", Some(1), 512,
      Command(executorClassname, Seq(), Map(), Seq(), Seq(), Seq()), "ignored")
    val listener = new TestListener
    val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf)
    client.start()
    rpcEnv.awaitTermination()
  }
} 
Example 145
Source File: SubmitRestProtocolRequest.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.rest

import scala.util.Try

import org.apache.spark.util.Utils


  private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = {
    sparkProperties.get(key).foreach { value =>
      // Scala 2.10 provides Try as a cleaner way to handle operations that may throw: wrapping
      // the computation yields a Success if it completes or a Failure carrying the exception.
      // getOrElse supplies the fallback, here throwing a SubmitRestProtocolException.
      Try(convert(value)).getOrElse {
        throw new SubmitRestProtocolException(
          s"Property '$key' expected $valueType value: actual was '$value'.")
      }
    }
  }
} 
Example 146
Source File: FileSystemPersistenceEngine.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import java.io._

import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
import org.apache.spark.util.Utils

private[master] class FileSystemPersistenceEngine(
    val dir: String,
    val serializer: Serializer)
  extends PersistenceEngine with Logging {

  new File(dir).mkdir()

  override def persist(name: String, obj: Object): Unit = {
    serializeIntoFile(new File(dir + File.separator + name), obj)
  }

  override def unpersist(name: String): Unit = {
    new File(dir + File.separator + name).delete()
  }

  override def read[T: ClassTag](prefix: String): Seq[T] = {
    val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix))
    files.map(deserializeFromFile[T])
  }

  private def serializeIntoFile(file: File, value: AnyRef) {
    val created = file.createNewFile()
    if (!created) { throw new IllegalStateException("Could not create file: " + file) }
    val fileOut = new FileOutputStream(file)
    var out: SerializationStream = null
    Utils.tryWithSafeFinally {
      out = serializer.newInstance().serializeStream(fileOut)
      out.writeObject(value)
    } {
      fileOut.close()
      if (out != null) {
        out.close()
      }
    }
  }

  private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
    val fileIn = new FileInputStream(file)
    var in: DeserializationStream = null
    try {
      in = serializer.newInstance().deserializeStream(fileIn)
      in.readObject[T]()
    } finally {
      fileIn.close()
      if (in != null) {
        in.close()
      }
    }
  }

} 
Example 147
Source File: DriverInfo.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import java.util.Date

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.DriverDescription
import org.apache.spark.util.Utils

private[deploy] class DriverInfo(
    val startTime: Long,
    val id: String,
    val desc: DriverDescription,
    val submitDate: Date)
  extends Serializable {

  @transient var state: DriverState.Value = DriverState.SUBMITTED
  // Filled in if launching the driver fails; reset to None by init()
  @transient var exception: Option[Exception] = None
  @transient var worker: Option[WorkerInfo] = None

  init()

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init(): Unit = {
    state = DriverState.SUBMITTED
    worker = None
    exception = None
  }
} 
Example 148
Source File: ApplicationInfo.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import java.util.Date

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.deploy.ApplicationDescription
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class ApplicationInfo(
  val startTime: Long,
  val id: String,
  val desc: ApplicationDescription,
  val submitDate: Date,
  val driver: RpcEndpointRef,
  defaultCores: Int)
    extends Serializable {
  // Enumeration-typed field (assigned later)
  @transient var state: ApplicationState.Value = _
  @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _
  @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _
  @transient var coresGranted: Int = _
  @transient var endTime: Long = _
  @transient var appSource: ApplicationSource = _

  // A cap on the number of executors this application can have at any given time.
  // By default, this is infinite. Only after the first allocation request is issued by the
  // application will this be set to a finite value. This is used for dynamic allocation.
  @transient private[master] var executorLimit: Int = _

  @transient private var nextExecutorId: Int = _

  init() // Initialization method

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }
  
  private[deploy] def getExecutorLimit: Int = executorLimit

  def duration: Long = {
    if (endTime != -1) {
      endTime - startTime
    } else {
      System.currentTimeMillis() - startTime
    }
  }

} 
Example 149
Source File: MasterArguments.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import org.apache.spark.SparkConf
import org.apache.spark.util.{IntParam, Utils}


  private def printUsageAndExit(exitCode: Int) {
    // scalastyle:off println
    System.err.println(
      "Usage: Master [options]\n" +
      "\n" +
      "Options:\n" +
      "  -i HOST, --ip HOST     Hostname to listen on (deprecated, please use --host or -h) \n" +
      "  -h HOST, --host HOST   Hostname to listen on\n" +
      "  -p PORT, --port PORT   Port to listen on (default: 7077)\n" +
      "  --webui-port PORT      Port for web UI (default: 8080)\n" +
      "  --properties-file FILE Path to a custom Spark properties file.\n" +
      "                         Default is conf/spark-defaults.conf.")
    // scalastyle:on println
    System.exit(exitCode)
  }
} 
Example 150
Source File: ApplicationPage.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import javax.servlet.http.HttpServletRequest

import scala.xml.Node

import org.apache.spark.deploy.ExecutorState
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.ExecutorDesc
import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils

private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") {

  private val master = parent.masterEndpointRef

  
  def render(request: HttpServletRequest): Seq[Node] = {
    val appId = request.getParameter("appId")
    val state = master.askWithRetry[MasterStateResponse](RequestMasterState)
    val app = state.activeApps.find(_.id == appId).getOrElse({
      state.completedApps.find(_.id == appId).getOrElse(null)
    })
    if (app == null) {
      val msg = <div class="row-fluid">No running application with ID {appId}</div>
      return UIUtils.basicSparkPage(msg, "Not Found")
    }

    val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs")
    val allExecutors = (app.executors.values ++ app.removedExecutors).toSet.toSeq
    // This includes executors that are either still running or have exited cleanly
    val executors = allExecutors.filter { exec =>
      !ExecutorState.isFinished(exec.state) || exec.state == ExecutorState.EXITED
    }
    val removedExecutors = allExecutors.diff(executors)
    val executorsTable = UIUtils.listingTable(executorHeaders, executorRow, executors)
    val removedExecutorsTable = UIUtils.listingTable(executorHeaders, executorRow, removedExecutors)

    val content =
      <div class="row-fluid">
        <div class="span12">
          <ul class="unstyled">
            <li><strong>ID:</strong> {app.id}</li>
            <li><strong>Name:</strong> {app.desc.name}</li>
            <li><strong>User:</strong> {app.desc.user}</li>
            <li><strong>Cores:</strong>
            {
              if (app.desc.maxCores.isEmpty) {
                "Unlimited (%s granted)".format(app.coresGranted)
              } else {
                "%s (%s granted, %s left)".format(
                  app.desc.maxCores.get, app.coresGranted, app.coresLeft)
              }
            }
            </li>
            <li>
              <strong>Executor Memory:</strong>
              {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
            </li>
            <li><strong>Submit Date:</strong> {app.submitDate}</li>
            <li><strong>State:</strong> {app.state}</li>
            <li><strong><a href={app.desc.appUiUrl}>Application Detail UI</a></strong></li>
          </ul>
        </div>
      </div>

      <div class="row-fluid"> <!-- Executors -->
        <div class="span12">
          <h4> Executor Summary </h4>
          {executorsTable}
          {
            if (removedExecutors.nonEmpty) {
              <h4> Removed Executors </h4> ++
              removedExecutorsTable
            }
          }
        </div>
      </div>;
    UIUtils.basicSparkPage(content, "Application: " + app.desc.name)
  }

  private def executorRow(executor: ExecutorDesc): Seq[Node] = {
    <tr>
      <td>{executor.id}</td>
      <td>
        <a href={executor.worker.webUiAddress}>{executor.worker.id}</a>
      </td>
      <td>{executor.cores}</td>
      <td>{executor.memory}</td>
      <td>{executor.state}</td>
      <td>
        <a href={"%s/logPage?appId=%s&executorId=%s&logType=stdout"
          .format(executor.worker.webUiAddress, executor.application.id, executor.id)}>stdout</a>
        <a href={"%s/logPage?appId=%s&executorId=%s&logType=stderr"
          .format(executor.worker.webUiAddress, executor.application.id, executor.id)}>stderr</a>
      </td>
    </tr>
  }
} 
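The page above renders executor memory with Utils.megabytesToString. A tiny illustrative sketch (the exact formatting in the comments is approximate), assuming a file under the org.apache.spark package:

package org.apache.spark.examples.util

import org.apache.spark.util.Utils

object MegabytesToStringSketch {
  def main(args: Array[String]): Unit = {
    // Renders a size given in megabytes with a human-friendly unit, as used in the master UI.
    println(Utils.megabytesToString(512))   // e.g. "512.0 MB"
    println(Utils.megabytesToString(4096))  // e.g. "4.0 GB"
  }
}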
Example 151
Source File: WorkerInfo.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import scala.collection.mutable

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

private[spark] class WorkerInfo(
    val id: String, // worker ID
    val host: String, // worker IP / hostname
    val port: Int, // worker port
    val cores: Int, // CPU cores on the worker node
    val memory: Int, // memory (MB) on the worker node
    val endpoint: RpcEndpointRef,
    val webUiPort: Int,
    val publicAddress: String)
  extends Serializable {

  Utils.checkHost(host, "Expected hostname")
  assert (port > 0)

  @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
  @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
  @transient var state: WorkerState.Value = _
  @transient var coresUsed: Int = _
  @transient var memoryUsed: Int = _

  @transient var lastHeartbeat: Long = _

  init() // Initialize transient data

  def coresFree: Int = cores - coresUsed
  def memoryFree: Int = memory - memoryUsed

  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }

  private def init() {
    executors = new mutable.HashMap
    drivers = new mutable.HashMap
    state = WorkerState.ALIVE // alive state
    coresUsed = 0
    memoryUsed = 0
    lastHeartbeat = System.currentTimeMillis()
  }

  def hostPort: String = {
    assert (port > 0)
    host + ":" + port
  }

  def addExecutor(exec: ExecutorDesc) {
    executors(exec.fullId) = exec
    coresUsed += exec.cores
    memoryUsed += exec.memory
  }

  def removeExecutor(exec: ExecutorDesc) {
    if (executors.contains(exec.fullId)) {
      executors -= exec.fullId
      coresUsed -= exec.cores
      memoryUsed -= exec.memory
    }
  }

  def hasExecutor(app: ApplicationInfo): Boolean = {
    executors.values.exists(_.application == app)
  }

  def addDriver(driver: DriverInfo) {
    drivers(driver.id) = driver
    memoryUsed += driver.desc.mem
    coresUsed += driver.desc.cores
  }

  def removeDriver(driver: DriverInfo) {
    drivers -= driver.id
    memoryUsed -= driver.desc.mem
    coresUsed -= driver.desc.cores
  }

  def webUiAddress : String = {
    "http://" + this.publicAddress + ":" + this.webUiPort
  }

  def setState(state: WorkerState.Value): Unit = {
    this.state = state
  }

  def isAlive(): Boolean = this.state == WorkerState.ALIVE
} 
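WorkerInfo validates its host field with the two-argument Utils.checkHost used throughout this Spark version, which asserts that the value is a bare hostname with no ":port" suffix. The sketch below shows how it behaves; it is illustrative only (the host names are made up) and assumes a file under the org.apache.spark package.

package org.apache.spark.examples.util

import org.apache.spark.util.Utils

object CheckHostSketch {
  def main(args: Array[String]): Unit = {
    Utils.checkHost("worker-7.example.com", "Expected hostname") // passes
    try {
      // A "host:port" value trips the assertion with the supplied message.
      Utils.checkHost("worker-7.example.com:7078", "Expected hostname")
    } catch {
      case e: AssertionError => println(s"rejected host:port -> ${e.getMessage}")
    }
  }
}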
Example 152
Source File: ClientArguments.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import java.net.{URI, URISyntaxException}

import scala.collection.mutable.ListBuffer

import org.apache.log4j.Level
import org.apache.spark.util.{IntParam, MemoryParam, Utils}


  private def printUsageAndExit(exitCode: Int) {
    // TODO: It wouldn't be too hard to allow users to submit their app and dependency jars
    //       separately similar to in the YARN client.
    val usage =
    s"""
       |Usage: DriverClient [options] launch <active-master> <jar-url> <main-class> [driver options]
       |Usage: DriverClient kill <active-master> <driver-id>
       |
      |Options:
       |   -c CORES, --cores CORES        Number of cores to request (default: $DEFAULT_CORES)
       |   -m MEMORY, --memory MEMORY     Megabytes of memory to request (default: $DEFAULT_MEMORY)
       |   -s, --supervise                Whether to restart the driver on failure
       |                                  (default: $DEFAULT_SUPERVISE)
       |   -v, --verbose                  Print more debugging output
     """.stripMargin
    // scalastyle:off println
    System.err.println(usage)
    // scalastyle:on println
    System.exit(exitCode)
  }
}

private[deploy] object ClientArguments {
  val DEFAULT_CORES = 1
  val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB
  val DEFAULT_SUPERVISE = false

  def isValidJarUrl(s: String): Boolean = {
    try {
      val uri = new URI(s)
      uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar")
    } catch {
      case _: URISyntaxException => false
    }
  }
} 
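ClientArguments is private[deploy], so isValidJarUrl cannot be called from user code; the sketch below reproduces the same check as a free-standing helper to show what it accepts (all names here are illustrative, not part of Spark):

import java.net.{URI, URISyntaxException}

object JarUrlCheck {
  def isValidJarUrl(s: String): Boolean =
    try {
      val uri = new URI(s)
      uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar")
    } catch {
      case _: URISyntaxException => false
    }

  def main(args: Array[String]): Unit = {
    println(isValidJarUrl("hdfs://host:9000/jars/app.jar")) // true: has a scheme and a .jar path
    println(isValidJarUrl("app.jar"))                       // false: no scheme
    println(isValidJarUrl("file:/tmp/app.zip"))             // false: not a .jar
  }
}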
Example 153
Source File: DriverWrapper.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import java.io.File

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}


      case workerUrl :: userJar :: mainClass :: extraArgs =>
        val conf = new SparkConf()
        val rpcEnv = RpcEnv.create("Driver",
          Utils.localHostName(), 0, conf, new SecurityManager(conf))
        rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl))
        // Thread.currentThread().getContextClassLoader returns the context class loader of the current thread
        val currentLoader = Thread.currentThread.getContextClassLoader
        val userJarUrl = new File(userJar).toURI().toURL()
        val loader =
          if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
            new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader)
          } else {
            new MutableURLClassLoader(Array(userJarUrl), currentLoader)
          }
        // Install the chosen loader as the current thread's context class loader
        Thread.currentThread.setContextClassLoader(loader)

        // Delegate to supplied main class
        val clazz = Utils.classForName(mainClass)
        val mainMethod = clazz.getMethod("main", classOf[Array[String]])
        mainMethod.invoke(null, extraArgs.toArray[String])

        rpcEnv.shutdown()

      case _ =>
        // scalastyle:off println
        System.err.println("Usage: DriverWrapper <workerUrl> <userJar> <driverMainClass> [options]")
        // scalastyle:on println
        System.exit(-1)
    }
  }
} 
Example 154
Source File: HistoryServerArguments.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.history

import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.Utils


private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
  extends Logging {
  private var propertiesFile: String = null

  parse(args.toList)

  private def parse(args: List[String]): Unit = {
    args match {
      case ("--dir" | "-d") :: value :: tail =>
        logWarning("Setting log directory through the command line is deprecated as of " +
          "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.")
        conf.set("spark.history.fs.logDirectory", value)
        System.setProperty("spark.history.fs.logDirectory", value)
        parse(tail)

      case ("--help" | "-h") :: tail =>
        printUsageAndExit(0)

      case ("--properties-file") :: value :: tail =>
        propertiesFile = value
        parse(tail)
      // Nil is the empty list; :: prepends an element to the head, producing a new list
      case Nil =>

      case _ =>
        printUsageAndExit(1)
    }
  }

  // This mutates the SparkConf, so all accesses to it must be made after this line
  Utils.loadDefaultSparkProperties(conf, propertiesFile)

  private def printUsageAndExit(exitCode: Int) {
    // scalastyle:off println
    System.err.println(
      """
      |Usage: HistoryServer [options]
      |
      |Options:
      |  --properties-file FILE      Path to a custom Spark properties file.
      |                              Default is conf/spark-defaults.conf.
      |
      |Configuration options can be set by setting the corresponding JVM system property.
      |History Server options are always available; additional options depend on the provider.
      |
      |History Server options:
      |
      |  spark.history.ui.port              Port where server will listen for connections
      |                                     (default 18080)
      |  spark.history.acls.enable          Whether to enable view acls for all applications
      |                                     (default false)
      |  spark.history.provider             Name of history provider class (defaults to
      |                                     file system-based provider)
      |  spark.history.retainedApplications Max number of application UIs to keep loaded in memory
      |                                     (default 50)
      |FsHistoryProvider options:
      |
      |  spark.history.fs.logDirectory      Directory where app logs are stored
      |                                     (default: file:/tmp/spark-events)
      |  spark.history.fs.updateInterval    How often to reload log data from storage
      |                                     (in seconds, default: 10)
      |""".stripMargin)
    // scalastyle:on println
    System.exit(exitCode)
  }

} 
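The parse method above leans on Scala's List pattern matching: :: splits the argument list into head and tail, and Nil ends the recursion. A small sketch of the same parser shape, using made-up option names purely for illustration:

// Hypothetical options, just to show the recursive head :: tail matching style.
def parse(args: List[String], acc: Map[String, String] = Map.empty): Map[String, String] =
  args match {
    case ("--dir" | "-d") :: value :: tail    => parse(tail, acc + ("logDir" -> value))
    case "--properties-file" :: value :: tail => parse(tail, acc + ("propertiesFile" -> value))
    case Nil                                  => acc // empty list: all arguments consumed
    case unknown :: _                         => sys.error(s"Unrecognized option: $unknown")
  }

println(parse(List("--dir", "/tmp/logs", "--properties-file", "conf/app.conf")))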
Example 155
Source File: SparkHadoopMapReduceUtil.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mapreduce

import java.lang.{Boolean => JBoolean, Integer => JInteger}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
import org.apache.spark.util.Utils

private[spark]
trait SparkHadoopMapReduceUtil {
  def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
    val klass = firstAvailableClass(
        "org.apache.hadoop.mapreduce.task.JobContextImpl",  // hadoop2, hadoop2-yarn
        "org.apache.hadoop.mapreduce.JobContext")           // hadoop1
    val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID])
    ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
  }

  def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = {
    val klass = firstAvailableClass(
        "org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl",  // hadoop2, hadoop2-yarn
        "org.apache.hadoop.mapreduce.TaskAttemptContext")           // hadoop1
    val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID])
    ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
  }

  def newTaskAttemptID(
      jtIdentifier: String,
      jobId: Int,
      isMap: Boolean,
      taskId: Int,
      attemptId: Int): TaskAttemptID = {
    val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID")
    try {
      // First, attempt to use the old-style constructor that takes a boolean isMap
      // (not available in YARN)
      val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
        classOf[Int], classOf[Int])
      ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
        new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
    } catch {
      case exc: NoSuchMethodException => {
        // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
        val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType")
          .asInstanceOf[Class[Enum[_]]]
        val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
          taskTypeClass, if (isMap) "MAP" else "REDUCE")
        val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
          classOf[Int], classOf[Int])
        ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
          new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
      }
    }
  }

  private def firstAvailableClass(first: String, second: String): Class[_] = {
    try {
      Utils.classForName(first)
    } catch {
      case e: ClassNotFoundException =>
        Utils.classForName(second)
    }
  }
} 
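firstAvailableClass is a plain reflective fallback: try the first fully-qualified class name and, only if it is missing, load the second. A standalone sketch of that pattern, using JDK class names so it runs without Hadoop on the classpath (Class.forName stands in for the private[spark] Utils.classForName):

def firstAvailableClass(first: String, second: String): Class[_] =
  try {
    Class.forName(first)
  } catch {
    case _: ClassNotFoundException => Class.forName(second)
  }

// "java.util.NoSuchClass" does not exist, so the second name is loaded.
println(firstAvailableClass("java.util.NoSuchClass", "java.util.ArrayList").getName)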
Example 156
Source File: TaskResult.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.io._
import java.nio.ByteBuffer

import scala.collection.Map
import scala.collection.mutable

import org.apache.spark.SparkEnv
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockId
import org.apache.spark.util.Utils

// Task result. Also contains updates to accumulator variables.

private[spark] sealed trait TaskResult[T]


  def value(): T = {
    if (valueObjectDeserialized) {
      valueObject
    } else {
      // This should not run when holding a lock because it may cost dozens of seconds for a large
      // value.
      val resultSer = SparkEnv.get.serializer.newInstance()
      valueObject = resultSer.deserialize(valueBytes)
      valueObjectDeserialized = true
      valueObject
    }
  }
} 
Example 157
Source File: RDDInfo.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDDOperationScope, RDD}
import org.apache.spark.util.Utils

@DeveloperApi
class RDDInfo(
    val id: Int,
    val name: String,
    val numPartitions: Int, // number of partitions
    var storageLevel: StorageLevel, // storage level
    val parentIds: Seq[Int], // ids of the parent RDDs
    val scope: Option[RDDOperationScope] = None)
  extends Ordered[RDDInfo] {

  var numCachedPartitions = 0 // number of cached partitions
  var memSize = 0L // size in memory
  var diskSize = 0L // size on disk
  var externalBlockStoreSize = 0L // size in the external block store
  // whether any partition of this RDD is cached
  def isCached: Boolean =
    (memSize + diskSize + externalBlockStoreSize > 0) && numCachedPartitions > 0

  override def toString: String = {
    import Utils.bytesToString
    ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " +
      "MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s").format(
        name, id, storageLevel.toString, numCachedPartitions, numPartitions,
        bytesToString(memSize), bytesToString(externalBlockStoreSize), bytesToString(diskSize))
  }

  override def compare(that: RDDInfo): Int = {
    this.id - that.id
  }
}

private[spark] object RDDInfo {
  def fromRdd(rdd: RDD[_]): RDDInfo = {
    val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd))
    val parentIds = rdd.dependencies.map(_.rdd.id) // ids of the RDDs this RDD depends on
    new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope)
  }
} 
Example 158
Source File: BlockManagerId.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


  def apply(execId: String, host: String, port: Int): BlockManagerId =
    getCachedBlockManagerId(new BlockManagerId(execId, host, port))

  def apply(in: ObjectInput): BlockManagerId = {
    val obj = new BlockManagerId()
    obj.readExternal(in) // read host, port and executorId from the ObjectInput
    getCachedBlockManagerId(obj) // return the canonical cached instance
  }

  val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()

  def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
    blockManagerIdCache.putIfAbsent(id, id) // unlike put, putIfAbsent does not replace an existing entry
    blockManagerIdCache.get(id) // return the cached BlockManagerId
  }
} 
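getCachedBlockManagerId is an interning cache: putIfAbsent stores the id only when no equal key is present, so every equal BlockManagerId resolves to one shared instance. A minimal sketch of the pattern with a hypothetical HostPort value class:

import java.util.concurrent.ConcurrentHashMap

case class HostPort(host: String, port: Int)

object HostPortCache {
  private val cache = new ConcurrentHashMap[HostPort, HostPort]()

  def intern(id: HostPort): HostPort = {
    cache.putIfAbsent(id, id) // no-op if an equal key is already cached
    cache.get(id)             // always return the canonical instance
  }
}

val a = HostPortCache.intern(HostPort("worker-1", 7077))
val b = HostPortCache.intern(HostPort("worker-1", 7077))
println(a eq b) // true: both references point to the same cached object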
Example 159
Source File: ZippedWithIndexRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

private[spark]
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}


  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
      Array[Long]()
    } else if (n == 1) {
      Array(0L)
    } else {
      prev.context.runJob(
        prev,
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)
    }
  }

  override def getPartitions: Array[Partition] = {
    firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
  }

  override def getPreferredLocations(split: Partition): Seq[String] =
    firstParent[T].preferredLocations(split.asInstanceOf[ZippedWithIndexRDDPartition].prev)

  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    firstParent[T].iterator(split.prev, context).zipWithIndex.map { x =>
      (x._1, split.startIndex + x._2)
    }
  }
} 
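ZippedWithIndexRDD backs the public RDD.zipWithIndex() operator; the startIndices computation above is why calling it on a multi-partition RDD first runs a small job to count the earlier partitions. A short usage sketch through the public API, assuming a local master:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("zip-with-index-demo"))
// zipWithIndex assigns a contiguous global index across all partitions.
val indexed = sc.parallelize(Seq("a", "b", "c", "d"), 2).zipWithIndex()
indexed.collect().foreach(println) // (a,0), (b,1), (c,2), (d,3)
sc.stop()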
Example 160
Source File: LocalRDDCheckpointData.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.Utils


  def transformStorageLevel(level: StorageLevel): StorageLevel = {
    // If this RDD is to be cached off-heap, fail fast since we cannot provide any
    // correctness guarantees about subsequent computations after the first one
    if (level.useOffHeap) {
      throw new SparkException("Local checkpointing is not compatible with off-heap caching.")
    }

    StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication)
  }
} 
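transformStorageLevel is applied when an RDD is locally checkpointed: the level is forced onto disk and off-heap caching is rejected. A hedged sketch of the public API that exercises it, assuming a local master:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("local-checkpoint-demo"))
val rdd = sc.parallelize(1 to 100).persist(StorageLevel.MEMORY_ONLY)
rdd.localCheckpoint() // storage level is upgraded to also use disk, as described above
rdd.count()           // materializes the local checkpoint
println(rdd.getStorageLevel.useDisk) // true
sc.stop()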
Example 161
Source File: CartesianRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.util.Utils

private[spark]

class CartesianPartition(
    idx: Int,
    @transient rdd1: RDD[_],
    @transient rdd2: RDD[_],
    s1Index: Int,
    s2Index: Int
  ) extends Partition {
  var s1 = rdd1.partitions(s1Index)
  var s2 = rdd2.partitions(s2Index)
  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    s1 = rdd1.partitions(s1Index)
    s2 = rdd2.partitions(s2Index)
    oos.defaultWriteObject()
  }
}

private[spark]
class CartesianRDD[T: ClassTag, U: ClassTag](
    sc: SparkContext,
    var rdd1 : RDD[T],
    var rdd2 : RDD[U])
  extends RDD[Pair[T, U]](sc, Nil)
  with Serializable {

  val numPartitionsInRdd2 = rdd2.partitions.length

  override def getPartitions: Array[Partition] = {
    // create the cross product split
    val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length)
    for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
      val idx = s1.index * numPartitionsInRdd2 + s2.index
      array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
    }
    array
  }

  override def getPreferredLocations(split: Partition): Seq[String] = {
    val currSplit = split.asInstanceOf[CartesianPartition]
    (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
  }

  override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
    val currSplit = split.asInstanceOf[CartesianPartition]
    for (x <- rdd1.iterator(currSplit.s1, context);
         y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
  }

  override def getDependencies: Seq[Dependency[_]] = List(
    new NarrowDependency(rdd1) {
      def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2)
    },
    new NarrowDependency(rdd2) {
      def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2)
    }
  )

  override def clearDependencies() {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
  }
} 
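CartesianRDD is the machinery behind RDD.cartesian: each partition of one parent is paired with each partition of the other, so the result has rdd1.partitions x rdd2.partitions partitions and no shuffle is needed. A small usage sketch, assuming a local master:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cartesian-demo"))
val numbers = sc.parallelize(1 to 3)
val letters = sc.parallelize(Seq("a", "b"))
val pairs = numbers.cartesian(letters) // 3 x 2 = 6 (Int, String) pairs
pairs.collect().foreach(println)
sc.stop()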
Example 162
Source File: UnionRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient rdd: RDD[T],
    val parentRddIndex: Int,
    @transient parentRddPartitionIndex: Int)
  extends Partition {

  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    oos.defaultWriteObject()
  }
}

@DeveloperApi
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](rdds.map(_.partitions.length).sum)
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1
    }
    array
  }

  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length
    }
    deps
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)
  }

  override def getPreferredLocations(s: Partition): Seq[String] =
    s.asInstanceOf[UnionPartition[T]].preferredLocations()

  override def clearDependencies() {
    super.clearDependencies()
    rdds = null
  }
} 
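UnionRDD typically backs RDD.union and SparkContext.union: the parent partitions are simply concatenated via RangeDependency, with no shuffle. A short usage sketch, assuming a local master:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("union-demo"))
val first = sc.parallelize(1 to 3, 2)
val second = sc.parallelize(4 to 6, 3)
val all = first.union(second)  // no shuffle; partitions are concatenated
println(all.partitions.length) // 5 = 2 + 3
println(all.collect().toList)  // List(1, 2, 3, 4, 5, 6)
sc.stop()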
Example 163
Source File: PartitionwiseSampledRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

private[spark]
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}


private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    @transient preservesPartitioning: Boolean,
    @transient seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
  }

  override def getPreferredLocations(split: Partition): Seq[String] =
    firstParent[T].preferredLocations(split.asInstanceOf[PartitionwiseSampledRDDPartition].prev)

  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.setSeed(split.seed)
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
  }
} 
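PartitionwiseSampledRDD is what RDD.sample builds: each partition gets its own clone of the sampler, seeded from the per-partition random seed generated above, so results are reproducible for a fixed seed. A usage sketch through the public API, assuming a local master:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sample-demo"))
val rdd = sc.parallelize(1 to 1000, 4)
// Bernoulli sampling without replacement; the exact count depends on the seed.
val sampled = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L)
println(sampled.count()) // roughly 100
sc.stop()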
Example 164
Source File: PythonBroadcastSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import scala.io.Source

import java.io.{PrintWriter, File}

import org.scalatest.Matchers

import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.Utils

// This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize
// a PythonBroadcast:
class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext {
  test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") {
    val tempDir = Utils.createTempDir()
    val broadcastedString = "Hello, world!"
    def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = {
      val source = Source.fromFile(broadcast.path)
      val contents = source.mkString
      source.close()
      contents should be (broadcastedString)
    }
    try {
      val broadcastDataFile: File = {
        val file = new File(tempDir, "broadcastData")
        val printWriter = new PrintWriter(file)
        printWriter.write(broadcastedString)
        printWriter.close()
        file
      }
      val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath)
      assertBroadcastIsValid(broadcast)
      val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
      val deserializedBroadcast =
        Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance())
      assertBroadcastIsValid(deserializedBroadcast)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
} 
Example 165
Source File: PythonRunnerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils

class PythonRunnerSuite extends SparkFunSuite {

  // Test formatting a single path to be added to the PYTHONPATH
  test("format path") {
    assert(PythonRunner.formatPath("spark.py") === "spark.py")
    assert(PythonRunner.formatPath("file:/spark.py") === "/spark.py")
    assert(PythonRunner.formatPath("file:///spark.py") === "/spark.py")
    assert(PythonRunner.formatPath("local:/spark.py") === "/spark.py")
    assert(PythonRunner.formatPath("local:///spark.py") === "/spark.py")
    if (Utils.isWindows) {
      assert(PythonRunner.formatPath("file:/C:/a/b/spark.py", testWindows = true) ===
        "C:/a/b/spark.py")
      assert(PythonRunner.formatPath("C:\\a\\b\\spark.py", testWindows = true) ===
        "C:/a/b/spark.py")
      assert(PythonRunner.formatPath("C:\\a b\\spark.py", testWindows = true) ===
        "C:/a b/spark.py")
    }
    intercept[IllegalArgumentException] { PythonRunner.formatPath("one:two") }
    intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:s3:xtremeFS") }
    intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:/path/to/some.py") }
  }

  // Test formatting multiple comma-separated paths to be added to the PYTHONPATH
  test("format paths") {
    assert(PythonRunner.formatPaths("spark.py") === Array("spark.py"))
    assert(PythonRunner.formatPaths("file:/spark.py") === Array("/spark.py"))
    assert(PythonRunner.formatPaths("file:/app.py,local:/spark.py") ===
      Array("/app.py", "/spark.py"))
    assert(PythonRunner.formatPaths("me.py,file:/you.py,local:/we.py") ===
      Array("me.py", "/you.py", "/we.py"))
    if (Utils.isWindows) {
      assert(PythonRunner.formatPaths("C:\\a\\b\\spark.py", testWindows = true) ===
        Array("C:/a/b/spark.py"))
      assert(PythonRunner.formatPaths("C:\\free.py,pie.py", testWindows = true) ===
        Array("C:/free.py", "pie.py"))
      assert(PythonRunner.formatPaths("lovely.py,C:\\free.py,file:/d:/fry.py",
        testWindows = true) ===
        Array("lovely.py", "C:/free.py", "d:/fry.py"))
    }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("one:two,three") }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("two,three,four:five:six") }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") }
    intercept[IllegalArgumentException] { PythonRunner.formatPaths("foo.py,hdfs:/some.py") }
  }
} 
Example 166
Source File: TestClient.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.client

import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SecurityManager, SparkConf}

private[spark] object TestClient {

  private class TestListener extends AppClientListener with Logging {
    def connected(id: String) {
      logInfo("Connected to master, got app ID " + id)
    }

    def disconnected() {
      logInfo("Disconnected from master")
      System.exit(0)
    }

    def dead(reason: String) {
      logInfo("Application died with error: " + reason)
      System.exit(0)
    }

    def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}

    def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
  }

  def main(args: Array[String]) {
    val url = if (args.isEmpty) "172.0.0.1" else args(0)
    
    val conf = new SparkConf
    val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf))
    // stripSuffix removes the given suffix from the end of the string, if present
    val executorClassname = TestExecutor.getClass.getCanonicalName.stripSuffix("$")
    println("====executorClassname======"+executorClassname)
    val desc = new ApplicationDescription("TestClient", Some(1), 512,
      Command(executorClassname, Seq(), Map(), Seq(), Seq(), Seq()), "ignored")
    val listener = new TestListener
    val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf)
    client.start()
    rpcEnv.awaitTermination()
  }
} 
Example 167
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Span, Seconds}

import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils


class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  override def beforeAll(): Unit = {
    super.beforeAll()
    val conf = new SparkConf()
      .set("master", "local[2,4]")
      .set("spark.speculation", "true")
      .set("spark.hadoop.mapred.output.committer.class",
        classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName)
    sc = new SparkContext("local[2, 4]", "test", conf)
  }

  test("exception thrown in OutputCommitter.commitTask()") {//异常抛出
    // Regression test for SPARK-10381
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {
        Utils.deleteRecursively(tempDir)
      }
    }
  }
}

private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      throw new java.io.FileNotFoundException("Intentional exception")
    }
    super.commitTask(context)
  }
} 
Example 168
Source File: DriverSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.io.File

import org.scalatest.concurrent.Timeouts
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.time.SpanSugar._

import org.apache.spark.util.Utils

class DriverSuite extends SparkFunSuite with Timeouts {
  test("driver should exit after finishing without cleanup (SPARK-530)") {
    // System.getenv() returns OS environment variables (e.g. exported in the user's ".bashrc"),
    // while System.getProperties() returns JVM properties passed with "-D" on the command line.
    // getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val sparkHome = sys.props.getOrElse("spark.test.home", "/software/spark152")

    // val masters = Table("master", "local", "local-cluster[2,1,1024]")
    val masters = Table("master", "local", "local[*]")
    forAll(masters) { (master: String) =>
      val process = Utils.executeCommand(
        Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master),
        new File(sparkHome),
        Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
      failAfter(60 seconds) { process.waitFor() }
      // Ensure we still kill the process in case it timed out
      // destroy() forcibly terminates the subprocess represented by this Process object
      process.destroy()
    }
  }
}


object DriverWithoutCleanup {
  def main(args: Array[String]) {
    Utils.configTestLog4j("INFO")
    val conf = new SparkConf
    //val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf)
    val sc = new SparkContext("local", "DriverWithoutCleanup", conf)
    sc.parallelize(1 to 100, 4).count()
  }
} 
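The comments in this test distinguish OS environment variables from JVM system properties. A tiny sketch of that distinction (nothing here is Spark-specific, and the property names are only examples):

// sys.env wraps System.getenv(): values exported by the shell, e.g. in .bashrc.
println(sys.env.getOrElse("SPARK_HOME", "<SPARK_HOME not set>"))
// sys.props wraps System.getProperties(): values passed to the JVM with -Dkey=value.
println(sys.props.getOrElse("spark.test.home", "<spark.test.home not set>"))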
Example 169
Source File: DiskBlockManagerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils
// DiskBlockManager maintains the mapping from logical blocks to their physical locations on disk.
// By default, a logical block maps to a single file whose name is derived from its BlockId.
class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  override def beforeAll() {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  override def afterAll() {
    super.afterAll()
    Utils.deleteRecursively(rootDir0)
    Utils.deleteRecursively(rootDir1)
  }

  override def beforeEach() {
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  override def afterEach() {
    diskBlockManager.stop()
  }

  test("basic block creation") {//基本块的创建
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {//枚举块
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  def writeToFile(file: File, numBytes: Int) {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }
} 
Example 170
Source File: LocalDirsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.File

import org.apache.spark.util.Utils
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}



class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter {

  before {
    Utils.clearLocalRootDirs()
  }
  test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
    // Regression test for SPARK-2974
    assert(!new File("/NONEXISTENT_DIR").exists())
    val conf = new SparkConf(false)
      .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}")
    println("===="+new File(Utils.getLocalDir(conf)).getName)
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

  test("SPARK_LOCAL_DIRS override also affects driver") {//重写也会影响驱动程序
    // Regression test for SPARK-2975
    assert(!new File("/NONEXISTENT_DIR").exists())
    // SPARK_LOCAL_DIRS is a valid directory:
    class MySparkConf extends SparkConf(false) {
      override def getenv(name: String): String = {
        if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir")
        else super.getenv(name)
      }

      override def clone: SparkConf = {
        new MySparkConf().setAll(getAll)
      }
    }
    // spark.local.dir only contains invalid directories, but that's not a problem since
    // SPARK_LOCAL_DIRS will override it on both the driver and workers:
    // spark.local.dir is the scratch-space directory used for map output files and spilled RDDs
    val conf = new MySparkConf().set("spark.local.dir", "/NONEXISTENT_PATH")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

} 
Example 171
Source File: JdbcRDDSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.sql._

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils

class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  before {
    Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver")
    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
    try {

      try {
        val create = conn.createStatement
        create.execute("""
          CREATE TABLE FOO(
            ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
            DATA INTEGER
          )""")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
        (1 to 100).foreach { i =>
          insert.setInt(1, i * 2)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

      try {
        val create = conn.createStatement
        create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)")
        (1 to 100).foreach { i =>
          insert.setLong(1, 100000000000000000L +  4000000000000000L * i)
          insert.setInt(2, i)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

    } finally {
      conn.close()
    }
  }

  test("basic functionality") {//基本功能
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      // zero-argument function that opens the JDBC connection
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3, // lower bound, upper bound, number of partitions
      (r: ResultSet) => { r.getInt(1) } ).cache()

    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 10100)
  }

  test("large id overflow") {//大ID溢出
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?",
      1131544775L, 567279358897692673L, 20,
      (r: ResultSet) => { r.getInt(1) } ).cache()
    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 5050)
  }

  after {
    try {
      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
      case se: SQLException if se.getSQLState == "08006" =>
        // Normal single database shutdown
        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
} 
Example 172
Source File: HBaseSQLTableScan.scala    From Heracles   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hbase.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.hbase._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils


@DeveloperApi
case class HBaseSQLTableScan(
                              relation: HBaseRelation,
                              output: Seq[Attribute],
                              result: RDD[InternalRow]) extends SparkPlan {
  override def children: Seq[SparkPlan] = Nil

  override def outputPartitioning = {
    var ordering = List[SortOrder]()
    for (key <- relation.partitionKeys) {
      ordering = ordering :+ SortOrder(key, Ascending)
    }
    RangePartitioning(ordering, relation.partitions.size)
  }

  override protected def doExecute(): RDD[InternalRow] = {
    val schema = StructType.fromAttributes(output)
    result.mapPartitionsInternal { iter =>
      val proj = UnsafeProjection.create(schema)
      iter.map(proj)
    }
  }

  override def nodeName: String = getClass.getSimpleName

  override def argString: String =
    (Utils.truncatedString(output, "[", ", ", "]") :: Nil).mkString(", ")
} 
Example 173
Source File: UnaryTransformerExample.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.DoubleParam
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DataType, DataTypes}
import org.apache.spark.util.Utils
// $example off$


  object MyTransformer extends DefaultParamsReadable[MyTransformer]
  // $example off$

  def main(args: Array[String]) {
    val spark = SparkSession
      .builder()
      .appName("UnaryTransformerExample")
      .getOrCreate()

    // $example on$
    val myTransformer = new MyTransformer()
      .setShift(0.5)
      .setInputCol("input")
      .setOutputCol("output")

    // Create data, transform, and display it.
    val data = spark.range(0, 5).toDF("input")
      .select(col("input").cast("double").as("input"))
    val result = myTransformer.transform(data)
    println("Transformed by adding constant value")
    result.show()

    // Save and load the Transformer.
    val tmpDir = Utils.createTempDir()
    val dirName = tmpDir.getCanonicalPath
    myTransformer.write.overwrite().save(dirName)
    val sameTransformer = MyTransformer.load(dirName)

    // Transform the data to show the results are identical.
    println("Same transform applied from loaded model")
    val sameResult = sameTransformer.transform(data)
    sameResult.show()

    Utils.deleteRecursively(tmpDir)
    // $example off$

    spark.stop()
  }
}
// scalastyle:on println 
Example 174
Source File: DataFrameExample.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

import java.io.File

import scopt.OptionParser

import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.Utils


object DataFrameExample {

  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
    extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DataFrameExample") {
      head("DataFrameExample: an example app using DataFrame for ML.")
      opt[String]("input")
        .text("input path to dataframe")
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        success
      }
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case _ => sys.exit(1)
    }
  }

  def run(params: Params): Unit = {
    val spark = SparkSession
      .builder
      .appName(s"DataFrameExample with $params")
      .getOrCreate()

    // Load input data
    println(s"Loading LIBSVM file with UDT from ${params.input}.")
    val df: DataFrame = spark.read.format("libsvm").load(params.input).cache()
    println("Schema from LIBSVM:")
    df.printSchema()
    println(s"Loaded training data as a DataFrame with ${df.count()} records.")

    // Show statistical summary of labels.
    val labelSummary = df.describe("label")
    labelSummary.show()

    // Convert features column to an RDD of vectors.
    val features = df.select("features").rdd.map { case Row(v: Vector) => v }
    val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
      (summary, feat) => summary.add(Vectors.fromML(feat)),
      (sum1, sum2) => sum1.merge(sum2))
    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

    // Save the records in a parquet file.
    val tmpDir = Utils.createTempDir()
    val outputDir = new File(tmpDir, "dataframe").toString
    println(s"Saving to $outputDir as Parquet file.")
    df.write.parquet(outputDir)

    // Load the records back.
    println(s"Loading Parquet file with UDT from $outputDir.")
    val newDF = spark.read.parquet(outputDir)
    println("Schema from Parquet:")
    newDF.printSchema()

    spark.stop()
  }
}
// scalastyle:on println 
Example 175
Source File: StreamingTestExample.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.examples.mllib

import org.apache.spark.SparkConf
import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.util.Utils


object StreamingTestExample {

  def main(args: Array[String]) {
    if (args.length != 3) {
      // scalastyle:off println
      System.err.println(
        "Usage: StreamingTestExample " +
          "<dataDir> <batchDuration> <numBatchesTimeout>")
      // scalastyle:on println
      System.exit(1)
    }
    val dataDir = args(0)
    val batchDuration = Seconds(args(1).toLong)
    val numBatchesTimeout = args(2).toInt

    val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample")
    val ssc = new StreamingContext(conf, batchDuration)
    ssc.checkpoint {
      val dir = Utils.createTempDir()
      dir.toString
    }

    // $example on$
    val data = ssc.textFileStream(dataDir).map(line => line.split(",") match {
      case Array(label, value) => BinarySample(label.toBoolean, value.toDouble)
    })

    val streamingTest = new StreamingTest()
      .setPeacePeriod(0)
      .setWindowSize(0)
      .setTestMethod("welch")

    val out = streamingTest.registerStream(data)
    out.print()
    // $example off$

    // Stop processing if test becomes significant or we time out
    var timeoutCounter = numBatchesTimeout
    out.foreachRDD { rdd =>
      timeoutCounter -= 1
      val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _)
      if (timeoutCounter == 0 || anySignificant) rdd.context.stop()
    }

    ssc.start()
    ssc.awaitTermination()
  }
} 
Example 176
Source File: DriverSubmissionTest.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples

import scala.collection.JavaConverters._

import org.apache.spark.util.Utils


object DriverSubmissionTest {
  def main(args: Array[String]) {
    if (args.length < 1) {
      println("Usage: DriverSubmissionTest <seconds-to-sleep>")
      System.exit(0)
    }
    val numSecondsToSleep = args(0).toInt

    val env = System.getenv()
    val properties = Utils.getSystemProperties

    println("Environment variables containing SPARK_TEST:")
    env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println)

    println("System properties containing spark.test:")
    properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println)

    for (i <- 1 until numSecondsToSleep) {
      println(s"Alive for $i out of $numSecondsToSleep seconds")
      Thread.sleep(1000)
    }
  }
}
// scalastyle:on println 
Example 177
Source File: EdgeRDDSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.graphx

import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext {

  test("cache, getStorageLevel") {
    // test to see if getStorageLevel returns correct value after caching
    withSpark { sc =>
      val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
      val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
      assert(edges.getStorageLevel == StorageLevel.NONE)
      edges.cache()
      assert(edges.getStorageLevel == StorageLevel.MEMORY_ONLY)
    }
  }

  test("checkpointing") {
    withSpark { sc =>
      val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
      val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
      sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath)
      edges.checkpoint()

      // EdgeRDD not yet checkpointed
      assert(!edges.isCheckpointed)
      assert(!edges.isCheckpointedAndMaterialized)
      assert(!edges.partitionsRDD.isCheckpointed)
      assert(!edges.partitionsRDD.isCheckpointedAndMaterialized)

      val data = edges.collect().toSeq // force checkpointing

      // EdgeRDD shows up as checkpointed, but internally it is not.
      // Only internal partitionsRDD is checkpointed.
      assert(edges.isCheckpointed)
      assert(!edges.isCheckpointedAndMaterialized)
      assert(edges.partitionsRDD.isCheckpointed)
      assert(edges.partitionsRDD.isCheckpointedAndMaterialized)

      assert(edges.collect().toSeq ===  data) // test checkpointed RDD
    }
  }

} 
Example 178
Source File: GraphLoaderSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.graphx

import java.io.File
import java.io.FileOutputStream
import java.io.OutputStreamWriter
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils

class GraphLoaderSuite extends SparkFunSuite with LocalSparkContext {

  test("GraphLoader.edgeListFile") {
    withSpark { sc =>
      val tmpDir = Utils.createTempDir()
      val graphFile = new File(tmpDir.getAbsolutePath, "graph.txt")
      val writer = new OutputStreamWriter(new FileOutputStream(graphFile), StandardCharsets.UTF_8)
      for (i <- (1 until 101)) writer.write(s"$i 0\n")
      writer.close()
      try {
        val graph = GraphLoader.edgeListFile(sc, tmpDir.getAbsolutePath)
        val neighborAttrSums = graph.aggregateMessages[Int](
          ctx => ctx.sendToDst(ctx.srcAttr),
          _ + _)
        assert(neighborAttrSums.collect.toSet === Set((0: VertexId, 100)))
      } finally {
        Utils.deleteRecursively(tmpDir)
      }
    }
  }
} 
Example 179
Source File: HashingTFSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF}
import org.apache.spark.sql.Row
import org.apache.spark.util.Utils

class HashingTFSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._
  import HashingTFSuite.murmur3FeatureIdx

  test("params") {
    ParamsSuite.checkParams(new HashingTF)
  }

  test("hashingTF") {
    val numFeatures = 100
    // Assume perfect hash when computing expected features.
    def idx: Any => Int = murmur3FeatureIdx(numFeatures)
    val data = Seq(
      ("a a b b c d".split(" ").toSeq,
        Vectors.sparse(numFeatures,
          Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))))
    )

    val df = data.toDF("words", "expected")
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")
      .setNumFeatures(numFeatures)
    val output = hashingTF.transform(df)
    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
    require(attrGroup.numAttributes === Some(numFeatures))

    testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }
  }

  test("applying binary term freqs") {
    val df = Seq((0, "a a b c c c".split(" ").toSeq)).toDF("id", "words")
    val n = 100
    val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("features")
        .setNumFeatures(n)
        .setBinary(true)
    val output = hashingTF.transform(df)
    val features = output.select("features").first().getAs[Vector](0)
    def idx: Any => Int = murmur3FeatureIdx(n)  // Assume perfect hash on input features
    val expected = Vectors.sparse(n,
      Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
    assert(features ~== expected absTol 1e-14)
  }

  test("read/write") {
    val t = new HashingTF()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setNumFeatures(10)
    testDefaultReadWrite(t)
  }

}

object HashingTFSuite {

  private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
    Utils.nonNegativeMod(MLlibHashingTF.murmur3Hash(term), numFeatures)
  }

} 
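murmur3FeatureIdx maps a possibly negative murmur3 hash into [0, numFeatures) via Utils.nonNegativeMod; Scala's % keeps the sign of the dividend, so the remainder has to be shifted back into range. A standalone sketch of that fix-up (an assumed reimplementation, not the Spark source):

def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

println(nonNegativeMod(-7, 100))  // 93, whereas -7 % 100 == -7
println(nonNegativeMod(207, 100)) // 7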
Example 180
Source File: MatrixFactorizationModelSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.recommendation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext {

  val rank = 2
  var userFeatures: RDD[(Int, Array[Double])] = _
  var prodFeatures: RDD[(Int, Array[Double])] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
    prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
  }

  test("constructor") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)

    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
    }

    val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures1, prodFeatures)
    }

    val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
    }
  }

  test("save/load") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString
    def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
      features.mapValues(_.toSeq).collect().toSet
    }
    try {
      model.save(sc, path)
      val newModel = MatrixFactorizationModel.load(sc, path)
      assert(newModel.rank === rank)
      assert(collect(newModel.userFeatures) === collect(userFeatures))
      assert(collect(newModel.productFeatures) === collect(prodFeatures))
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }

  test("batch predict API recommendProductsForUsers") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val topK = 10
    val recommendations = model.recommendProductsForUsers(topK).collectAsMap()

    assert(recommendations(0)(0).rating ~== 17.0 relTol 1e-14)
    assert(recommendations(1)(0).rating ~== 39.0 relTol 1e-14)
  }

  test("batch predict API recommendUsersForProducts") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val topK = 10
    val recommendations = model.recommendUsersForProducts(topK).collectAsMap()

    assert(recommendations(2)(0).user == 1)
    assert(recommendations(2)(0).rating ~== 39.0 relTol 1e-14)
    assert(recommendations(2)(1).user == 0)
    assert(recommendations(2)(1).rating ~== 17.0 relTol 1e-14)
  }
} 
Example 181
Source File: MLlibTestSparkContext.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.util

import java.io.File

import org.scalatest.Suite

import org.apache.spark.SparkContext
import org.apache.spark.ml.util.TempDirectory
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.util.Utils

trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
  @transient var spark: SparkSession = _
  @transient var sc: SparkContext = _
  @transient var checkpointDir: String = _

  override def beforeAll() {
    super.beforeAll()
    spark = SparkSession.builder
      .master("local[2]")
      .appName("MLlibUnitTest")
      .getOrCreate()
    sc = spark.sparkContext

    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
    sc.setCheckpointDir(checkpointDir)
  }

  override def afterAll() {
    try {
      Utils.deleteRecursively(new File(checkpointDir))
      SparkSession.clearActiveSession()
      if (spark != null) {
        spark.stop()
      }
      spark = null
    } finally {
      super.afterAll()
    }
  }

  
  protected object testImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = self.spark.sqlContext
  }
} 
Example 182
Source File: RidgeRegressionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.regression

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext,
  MLlibTestSparkContext}
import org.apache.spark.util.Utils

private object RidgeRegressionSuite {

  
  val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}

class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

  def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
    predictions.zip(input).map { case (prediction, expected) =>
      (prediction - expected.label) * (prediction - expected.label)
    }.sum / predictions.size
  }

  test("ridge regression can help avoid overfitting") {

    // For small number of examples and large variance of error distribution,
    // ridge regression should give smaller generalization error that linear regression.

    val numExamples = 50
    val numFeatures = 20

    // Pick weights as random values distributed uniformly in [-0.5, 0.5]
    val random = new Random(42)
    val w = Array.fill(numFeatures)(random.nextDouble() - 0.5)

    // Use half of data for training and other half for validation
    val data = LinearDataGenerator.generateLinearInput(3.0, w, 2 * numExamples, 42, 10.0)
    val testData = data.take(numExamples)
    val validationData = data.takeRight(numExamples)

    val testRDD = sc.parallelize(testData, 2).cache()
    val validationRDD = sc.parallelize(validationData, 2).cache()

    // First run without regularization.
    val linearReg = new LinearRegressionWithSGD()
    linearReg.optimizer.setNumIterations(200)
                       .setStepSize(1.0)

    val linearModel = linearReg.run(testRDD)
    val linearErr = predictionError(
        linearModel.predict(validationRDD.map(_.features)).collect(), validationData)

    val ridgeReg = new RidgeRegressionWithSGD()
    ridgeReg.optimizer.setNumIterations(200)
                      .setRegParam(0.1)
                      .setStepSize(1.0)
    val ridgeModel = ridgeReg.run(testRDD)
    val ridgeErr = predictionError(
        ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)

    // The ridge validation error should be lower than the linear regression error.
    assert(ridgeErr < linearErr,
      "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
  }

  test("model save/load") {
    val model = RidgeRegressionSuite.model

    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString

    // Save model, load it back, and compare.
    try {
      model.save(sc, path)
      val sameModel = RidgeRegressionModel.load(sc, path)
      assert(model.weights == sameModel.weights)
      assert(model.intercept == sameModel.intercept)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}

class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {

  test("task size should be small in both training and prediction") {
    val m = 4
    val n = 200000
    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
      val random = new Random(idx)
      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
    }.cache()
    // If we serialize data directly in the task closure, the size of the serialized task would be
    // greater than 1MB and hence Spark would throw an error.
    val model = RidgeRegressionWithSGD.train(points, 2)
    val predictions = model.predict(points.map(_.features))
  }
} 
Example 183
Source File: SparkPodInitContainer.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.k8s

import java.io.File
import java.util.concurrent.TimeUnit

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.internal.Logging
import org.apache.spark.util.{ThreadUtils, Utils}


private[spark] class SparkPodInitContainer(
    sparkConf: SparkConf,
    fileFetcher: FileFetcher) extends Logging {

  private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE)
  private implicit val downloadExecutor = ExecutionContext.fromExecutorService(
    ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize))

  private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION))
  private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION))

  private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS)
  private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES)

  private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT)

  def run(): Unit = {
    logInfo(s"Downloading remote jars: $remoteJars")
    downloadFiles(
      remoteJars,
      jarsDownloadDir,
      s"Remote jars download directory specified at $jarsDownloadDir does not exist " +
        "or is not a directory.")

    logInfo(s"Downloading remote files: $remoteFiles")
    downloadFiles(
      remoteFiles,
      filesDownloadDir,
      s"Remote files download directory specified at $filesDownloadDir does not exist " +
        "or is not a directory.")

    downloadExecutor.shutdown()
    downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES)
  }

  private def downloadFiles(
      filesCommaSeparated: Option[String],
      downloadDir: File,
      errMessage: String): Unit = {
    filesCommaSeparated.foreach { files =>
      require(downloadDir.isDirectory, errMessage)
      Utils.stringToSeq(files).foreach { file =>
        Future[Unit] {
          fileFetcher.fetchFile(file, downloadDir)
        }
      }
    }
  }
}

private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) {

  def fetchFile(uri: String, targetDir: File): Unit = {
    Utils.fetchFile(
      url = uri,
      targetDir = targetDir,
      conf = sparkConf,
      securityMgr = securityManager,
      hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf),
      timestamp = System.currentTimeMillis(),
      useCache = false)
  }
}

object SparkPodInitContainer extends Logging {

  def main(args: Array[String]): Unit = {
    logInfo("Starting init-container to download Spark application dependencies.")
    val sparkConf = new SparkConf(true)
    if (args.nonEmpty) {
      Utils.loadDefaultSparkProperties(sparkConf, args(0))
    }

    val securityManager = new SparkSecurityManager(sparkConf)
    val fileFetcher = new FileFetcher(sparkConf, securityManager)
    new SparkPodInitContainer(sparkConf, fileFetcher).run()
    logInfo("Finished downloading application dependencies.")
  }
} 
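
downloadFiles above splits the comma-separated configuration value with Utils.stringToSeq, which trims each entry and drops blanks before the URIs are fetched. A hedged sketch of that behaviour follows; the package and object name are hypothetical, and Utils is private[spark], so the code must sit under org.apache.spark.

package org.apache.spark.examples.conf  // hypothetical package

import org.apache.spark.util.Utils

object StringToSeqSketch {
  def main(args: Array[String]): Unit = {
    val remoteJars = " http://repo/jar1.jar, hdfs://nn/jar2.jar,,  "
    val uris = Utils.stringToSeq(remoteJars)
    // Entries are trimmed and empty entries dropped:
    //   http://repo/jar1.jar
    //   hdfs://nn/jar2.jar
    uris.foreach(println)
  }
}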
Example 184
Source File: KubernetesUtils.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.k8s

import java.io.File

import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder}

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

private[spark] object KubernetesUtils {

  
  // Keep only URIs whose scheme is neither "file" nor "local", i.e. those that must be downloaded.
  def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = {
    uris.filter { uri =>
      val scheme = Utils.resolveURI(uri).getScheme
      scheme != "file" && scheme != "local"
    }
  }

  private def resolveFileUri(
      uri: String,
      fileDownloadPath: String,
      assumesDownloaded: Boolean): String = {
    val fileUri = Utils.resolveURI(uri)
    val fileScheme = Option(fileUri.getScheme).getOrElse("file")
    fileScheme match {
      case "local" =>
        fileUri.getPath
      case _ =>
        if (assumesDownloaded || fileScheme == "file") {
          val fileName = new File(fileUri.getPath).getName
          s"$fileDownloadPath/$fileName"
        } else {
          uri
        }
    }
  }
} 
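
Utils.resolveURI does the heavy lifting in both helpers above: it turns a bare path into a file: URI and leaves already-qualified URIs alone, so scheme checks become reliable. A small sketch, under the usual assumptions (hypothetical package and object; Utils is private[spark]):

package org.apache.spark.examples.k8s  // hypothetical package

import org.apache.spark.util.Utils

object ResolveUriSketch {
  // Returns the effective scheme, defaulting to "file" when none is present.
  def schemeOf(uri: String): String =
    Option(Utils.resolveURI(uri).getScheme).getOrElse("file")

  def main(args: Array[String]): Unit = {
    Seq("/opt/jars/app.jar", "local:///opt/jars/app.jar", "hdfs://nn:8020/jars/app.jar")
      .foreach(u => println(s"$u -> ${schemeOf(u)}"))
    // Expected, per resolveURI semantics: file, local, hdfs
  }
}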
Example 185
Source File: SparkPodInitContainerSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.k8s

import java.io.File
import java.util.UUID

import com.google.common.base.Charsets
import com.google.common.io.Files
import org.mockito.Mockito
import org.scalatest.BeforeAndAfter
import org.scalatest.mockito.MockitoSugar._

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.util.Utils

class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter {

  private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt")
  private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt")

  private var downloadJarsDir: File = _
  private var downloadFilesDir: File = _
  private var downloadJarsSecretValue: String = _
  private var downloadFilesSecretValue: String = _
  private var fileFetcher: FileFetcher = _

  override def beforeAll(): Unit = {
    downloadJarsSecretValue = Files.toString(
      new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8)
    downloadFilesSecretValue = Files.toString(
      new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8)
  }

  before {
    downloadJarsDir = Utils.createTempDir()
    downloadFilesDir = Utils.createTempDir()
    fileFetcher = mock[FileFetcher]
  }

  after {
    downloadJarsDir.delete()
    downloadFilesDir.delete()
  }

  test("Downloads from remote server should invoke the file fetcher") {
    val sparkConf = getSparkConfForRemoteFileDownloads
    val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher)
    initContainerUnderTest.run()
    Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir)
    Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir)
    Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir)
  }

  private def getSparkConfForRemoteFileDownloads: SparkConf = {
    new SparkConf(true)
      .set(INIT_CONTAINER_REMOTE_JARS,
        "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar")
      .set(INIT_CONTAINER_REMOTE_FILES,
        "http://localhost:9000/file.txt")
      .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath)
      .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath)
  }

  private def createTempFile(extension: String): String = {
    val dir = Utils.createTempDir()
    val file = new File(dir, s"${UUID.randomUUID().toString}.$extension")
    Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8)
    file.getAbsolutePath
  }
} 
Example 186
Source File: MesosClusterDispatcher.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.mesos

import java.util.concurrent.CountDownLatch

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.mesos.config._
import org.apache.spark.deploy.mesos.ui.MesosClusterUI
import org.apache.spark.deploy.rest.mesos.MesosRestServer
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.cluster.mesos._
import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, SparkUncaughtExceptionHandler, Utils}


private[mesos] class MesosClusterDispatcher(
    args: MesosClusterDispatcherArguments,
    conf: SparkConf)
  extends Logging {

  private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host)
  private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase()
  logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode)

  private val engineFactory = recoveryMode match {
    case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory
    case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf)
    case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode)
  }

  private val scheduler = new MesosClusterScheduler(engineFactory, conf)

  private val server = new MesosRestServer(args.host, args.port, conf, scheduler)
  private val webUi = new MesosClusterUI(
    new SecurityManager(conf),
    args.webUiPort,
    conf,
    publicAddress,
    scheduler)

  private val shutdownLatch = new CountDownLatch(1)

  def start(): Unit = {
    webUi.bind()
    scheduler.frameworkUrl = conf.get(DISPATCHER_WEBUI_URL).getOrElse(webUi.activeWebUiUrl)
    scheduler.start()
    server.start()
  }

  def awaitShutdown(): Unit = {
    shutdownLatch.await()
  }

  def stop(): Unit = {
    webUi.stop()
    server.stop()
    scheduler.stop()
    shutdownLatch.countDown()
  }
}

private[mesos] object MesosClusterDispatcher
  extends Logging
  with CommandLineUtils {

  override def main(args: Array[String]) {
    Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler)
    Utils.initDaemon(log)
    val conf = new SparkConf
    val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf)
    conf.setMaster(dispatcherArgs.masterUrl)
    conf.setAppName(dispatcherArgs.name)
    dispatcherArgs.zookeeperUrl.foreach { z =>
      conf.set(RECOVERY_MODE, "ZOOKEEPER")
      conf.set(ZOOKEEPER_URL, z)
    }
    val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf)
    dispatcher.start()
    logDebug("Adding shutdown hook") // force eager creation of logger
    ShutdownHookManager.addShutdownHook { () =>
      logInfo("Shutdown hook is shutting down dispatcher")
      dispatcher.stop()
      dispatcher.awaitShutdown()
    }
    dispatcher.awaitShutdown()
  }
} 
Example 187
Source File: MesosClusterPersistenceEngine.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster.mesos

import scala.collection.JavaConverters._

import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
import org.apache.zookeeper.KeeperException.NoNodeException

import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkCuratorUtil
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] class ZookeeperMesosClusterPersistenceEngine(
    baseDir: String,
    zk: CuratorFramework,
    conf: SparkConf)
  extends MesosClusterPersistenceEngine with Logging {
  private val WORKING_DIR =
    conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir

  SparkCuratorUtil.mkdir(zk, WORKING_DIR)

  def path(name: String): String = {
    WORKING_DIR + "/" + name
  }

  override def expunge(name: String): Unit = {
    zk.delete().forPath(path(name))
  }

  override def persist(name: String, obj: Object): Unit = {
    val serialized = Utils.serialize(obj)
    val zkPath = path(name)
    zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized)
  }

  override def fetch[T](name: String): Option[T] = {
    val zkPath = path(name)

    try {
      val fileData = zk.getData().forPath(zkPath)
      Some(Utils.deserialize[T](fileData))
    } catch {
      case e: NoNodeException => None
      case e: Exception =>
        logWarning("Exception while reading persisted file, deleting", e)
        zk.delete().forPath(zkPath)
        None
    }
  }

  override def fetchAll[T](): Iterable[T] = {
    zk.getChildren.forPath(WORKING_DIR).asScala.flatMap(fetch[T])
  }
} 
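
The persistence engine relies on Utils.serialize and Utils.deserialize, which are thin wrappers over Java serialization. A minimal round-trip sketch; the case class and package are hypothetical, and Utils is private[spark]:

package org.apache.spark.examples.persist  // hypothetical package

import org.apache.spark.util.Utils

object SerializeSketch {
  // Any Serializable value can be round-tripped; the driver state stored in
  // ZooKeeper above is just a byte array produced the same way.
  case class DriverState(id: String, retries: Int)

  def main(args: Array[String]): Unit = {
    val original = DriverState("driver-20180401-0001", retries = 2)
    val bytes: Array[Byte] = Utils.serialize(original)
    val restored = Utils.deserialize[DriverState](bytes)
    assert(restored == original)
    println(s"Round-tripped ${bytes.length} bytes: $restored")
  }
}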
Example 188
Source File: YARNHadoopDelegationTokenManager.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn.security

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.security.Credentials

import org.apache.spark.SparkConf
import org.apache.spark.deploy.security.HadoopDelegationTokenManager
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


  def obtainDelegationTokens(hadoopConf: Configuration, creds: Credentials): Long = {
    val superInterval = delegationTokenManager.obtainDelegationTokens(hadoopConf, creds)

    credentialProviders.values.flatMap { provider =>
      if (provider.credentialsRequired(hadoopConf)) {
        provider.obtainCredentials(hadoopConf, sparkConf, creds)
      } else {
        logDebug(s"Service ${provider.serviceName} does not require a token." +
          s" Check your configuration to see if security is disabled or not.")
        None
      }
    }.foldLeft(superInterval)(math.min)
  }

  private def getCredentialProviders: Map[String, ServiceCredentialProvider] = {
    val providers = loadCredentialProviders

    providers.
      filter { p => delegationTokenManager.isServiceEnabled(p.serviceName) }
      .map { p => (p.serviceName, p) }
      .toMap
  }

  private def loadCredentialProviders: List[ServiceCredentialProvider] = {
    ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader)
      .asScala
      .toList
  }
} 
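
Provider discovery above goes through java.util.ServiceLoader with Utils.getContextOrSparkClassLoader, which prefers the thread's context class loader and falls back to the loader that loaded Spark. A generic sketch of the same lookup pattern, with a hypothetical plugin trait (Utils is private[spark], hence the org.apache.spark package):

package org.apache.spark.examples.loader  // hypothetical package

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.util.Utils

object ServiceLoaderSketch {
  // A hypothetical plugin interface; implementations would be listed in
  // META-INF/services/org.apache.spark.examples.loader.ServiceLoaderSketch$Plugin.
  trait Plugin { def name: String }

  def loadPlugins(): Map[String, Plugin] = {
    ServiceLoader.load(classOf[Plugin], Utils.getContextOrSparkClassLoader)
      .asScala
      .map(p => p.name -> p)
      .toMap
  }
}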
Example 189
Source File: YarnRMClient.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import scala.collection.JavaConverters._

import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.webapp.util.WebAppUtils

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils


  def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
    val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt)
    val yarnMaxAttempts = yarnConf.getInt(
      YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
    sparkMaxAttempts match {
      case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
      case None => yarnMaxAttempts
    }
  }

} 
Example 190
Source File: YarnClusterSchedulerBackend.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.conf.YarnConfiguration

import org.apache.spark.SparkContext
import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils

private[spark] class YarnClusterSchedulerBackend(
    scheduler: TaskSchedulerImpl,
    sc: SparkContext)
  extends YarnSchedulerBackend(scheduler, sc) {

  override def start() {
    val attemptId = ApplicationMaster.getAttemptId
    bindToYarn(attemptId.getApplicationId(), Some(attemptId))
    super.start()
    totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sc.conf)
  }

  override def getDriverLogUrls: Option[Map[String, String]] = {
    var driverLogs: Option[Map[String, String]] = None
    try {
      val yarnConf = new YarnConfiguration(sc.hadoopConfiguration)
      val containerId = YarnSparkHadoopUtil.getContainerId

      val httpAddress = System.getenv(Environment.NM_HOST.name()) +
        ":" + System.getenv(Environment.NM_HTTP_PORT.name())
      // lookup appropriate http scheme for container log urls
      val yarnHttpPolicy = yarnConf.get(
        YarnConfiguration.YARN_HTTP_POLICY_KEY,
        YarnConfiguration.YARN_HTTP_POLICY_DEFAULT
      )
      val user = Utils.getCurrentUserName()
      val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
      val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user"
      logDebug(s"Base URL for logs: $baseUrl")
      driverLogs = Some(Map(
        "stdout" -> s"$baseUrl/stdout?start=-4096",
        "stderr" -> s"$baseUrl/stderr?start=-4096"))
    } catch {
      case e: Exception =>
        logInfo("Error while building AM log links, so AM" +
          " logs link will not appear in application UI", e)
    }
    driverLogs
  }
} 
Example 191
Source File: YarnScheduler.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.apache.hadoop.yarn.util.RackResolver
import org.apache.log4j.{Level, Logger}

import org.apache.spark._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils

private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {

  // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
  if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
    Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
  }

  // Override the default behaviour (rack unknown) by resolving the rack through Hadoop's RackResolver.
  override def getRackForHost(hostPort: String): Option[String] = {
    val host = Utils.parseHostPort(hostPort)._1
    Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation)
  }
} 
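
Utils.parseHostPort splits a "host:port" string into a (host, port) pair, returning 0 for the port when none is present; the scheduler above keeps only the host part. A small sketch under the usual assumptions (hypothetical package; Utils is private[spark]):

package org.apache.spark.examples.net  // hypothetical package

import org.apache.spark.util.Utils

object ParseHostPortSketch {
  def main(args: Array[String]): Unit = {
    val (host, port) = Utils.parseHostPort("node1.example.com:7337")
    println(s"host=$host port=$port")        // host=node1.example.com port=7337
    val (bareHost, noPort) = Utils.parseHostPort("node2.example.com")
    println(s"host=$bareHost port=$noPort")  // port defaults to 0 when absent
  }
}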
Example 192
Source File: SparkAWSCredentialsBuilderSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kinesis

import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.util.Utils

class SparkAWSCredentialsBuilderSuite extends TestSuiteBase {
  private def builder = SparkAWSCredentials.builder

  private val basicCreds = BasicCredentials(
    awsAccessKeyId = "a-very-nice-access-key",
    awsSecretKey = "a-very-nice-secret-key")

  private val stsCreds = STSCredentials(
    stsRoleArn = "a-very-nice-role-arn",
    stsSessionName = "a-very-nice-secret-key",
    stsExternalId = Option("a-very-nice-external-id"),
    longLivedCreds = basicCreds)

  test("should build DefaultCredentials when given no params") {
    assert(builder.build() == DefaultCredentials)
  }

  test("should build BasicCredentials") {
    assertResult(basicCreds) {
      builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
        .build()
    }
  }

  test("should build STSCredentials") {
    // No external ID, default long-lived creds
    assertResult(stsCreds.copy(stsExternalId = None, longLivedCreds = DefaultCredentials)) {
      builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName)
        .build()
    }
    // Default long-lived creds
    assertResult(stsCreds.copy(longLivedCreds = DefaultCredentials)) {
      builder.stsCredentials(
          stsCreds.stsRoleArn,
          stsCreds.stsSessionName,
          stsCreds.stsExternalId.get)
        .build()
    }
    // No external ID, basic keypair for long-lived creds
    assertResult(stsCreds.copy(stsExternalId = None)) {
      builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName)
        .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
        .build()
    }
    // Basic keypair for long-lived creds
    assertResult(stsCreds) {
      builder.stsCredentials(
          stsCreds.stsRoleArn,
          stsCreds.stsSessionName,
          stsCreds.stsExternalId.get)
        .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
        .build()
    }
    // Order shouldn't matter
    assertResult(stsCreds) {
      builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
        .stsCredentials(
          stsCreds.stsRoleArn,
          stsCreds.stsSessionName,
          stsCreds.stsExternalId.get)
        .build()
    }
  }

  test("SparkAWSCredentials classes should be serializable") {
    assertResult(basicCreds) {
      Utils.deserialize[BasicCredentials](Utils.serialize(basicCreds))
    }
    assertResult(stsCreds) {
      Utils.deserialize[STSCredentials](Utils.serialize(stsCreds))
    }
    // This also checks that DefaultCredentials can be serialized.
    val stsDefaultCreds = stsCreds.copy(longLivedCreds = DefaultCredentials)
    assertResult(stsDefaultCreds) {
      Utils.deserialize[STSCredentials](Utils.serialize(stsDefaultCreds))
    }
  }
} 
Example 193
Source File: FlumeTestUtils.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.{List => JList}
import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.commons.lang3.RandomUtils
import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol}
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder}

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils


  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }

} 
Example 194
Source File: EventTransformer.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import java.io.{ObjectInput, ObjectOutput}

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[streaming] object EventTransformer extends Logging {
  def readExternal(in: ObjectInput): (java.util.HashMap[CharSequence, CharSequence],
    Array[Byte]) = {
    val bodyLength = in.readInt()
    val bodyBuff = new Array[Byte](bodyLength)
    in.readFully(bodyBuff)

    val numHeaders = in.readInt()
    val headers = new java.util.HashMap[CharSequence, CharSequence]

    for (i <- 0 until numHeaders) {
      val keyLength = in.readInt()
      val keyBuff = new Array[Byte](keyLength)
      in.readFully(keyBuff)
      val key: String = Utils.deserialize(keyBuff)

      val valLength = in.readInt()
      val valBuff = new Array[Byte](valLength)
      in.readFully(valBuff)
      val value: String = Utils.deserialize(valBuff)

      headers.put(key, value)
    }
    (headers, bodyBuff)
  }

  def writeExternal(out: ObjectOutput, headers: java.util.Map[CharSequence, CharSequence],
    body: Array[Byte]) {
    out.writeInt(body.length)
    out.write(body)
    val numHeaders = headers.size()
    out.writeInt(numHeaders)
    for ((k, v) <- headers.asScala) {
      val keyBuff = Utils.serialize(k.toString)
      out.writeInt(keyBuff.length)
      out.write(keyBuff)
      val valBuff = Utils.serialize(v.toString)
      out.writeInt(valBuff.length)
      out.write(valBuff)
    }
  }
} 
Example 195
Source File: TestOutputStream.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.{IOException, ObjectInputStream}
import java.util.concurrent.ConcurrentLinkedQueue

import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.{DStream, ForEachDStream}
import org.apache.spark.util.Utils


class TestOutputStream[T: ClassTag](parent: DStream[T],
    val output: ConcurrentLinkedQueue[Seq[T]] = new ConcurrentLinkedQueue[Seq[T]]())
  extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
    val collected = rdd.collect()
    output.add(collected)
  }, false) {

  // Clear the output buffer every time this stream is restored from a checkpoint.
  @throws(classOf[IOException])
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    ois.defaultReadObject()
    output.clear()
  }
} 
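
Utils.tryOrIOException wraps a block so that any non-fatal exception is rethrown as an IOException, which is what custom readObject hooks are contractually allowed to throw. A hedged sketch of the same idiom on a serializable class (the class itself is hypothetical; Utils is private[spark]):

package org.apache.spark.examples.serde  // hypothetical package

import java.io.{IOException, ObjectInputStream}

import org.apache.spark.util.Utils

class ResettableBuffer extends Serializable {
  @transient private var buffer = new scala.collection.mutable.ArrayBuffer[String]()

  def add(s: String): Unit = buffer += s

  // Rebuild the transient buffer on deserialization; any failure surfaces as an IOException.
  @throws(classOf[IOException])
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    ois.defaultReadObject()
    buffer = new scala.collection.mutable.ArrayBuffer[String]()
  }
}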
Example 196
Source File: KafkaWriter.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.types.{BinaryType, StringType}
import org.apache.spark.util.Utils


private[kafka010] object KafkaWriter extends Logging {
  val TOPIC_ATTRIBUTE_NAME: String = "topic"
  val KEY_ATTRIBUTE_NAME: String = "key"
  val VALUE_ATTRIBUTE_NAME: String = "value"

  override def toString: String = "KafkaWriter"

  def validateQuery(
      schema: Seq[Attribute],
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
      if (topic.isEmpty) {
        throw new AnalysisException(s"topic option required when no " +
          s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
          s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
      } else {
        Literal(topic.get, StringType)
      }
    ).dataType match {
      case StringType => // good
      case _ =>
        throw new AnalysisException(s"Topic type must be a String")
    }
    schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse(
      Literal(null, StringType)
    ).dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " +
          s"must be a String or BinaryType")
    }
    schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse(
      throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found")
    ).dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
          s"must be a String or BinaryType")
    }
  }

  def write(
      sparkSession: SparkSession,
      queryExecution: QueryExecution,
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    val schema = queryExecution.analyzed.output
    validateQuery(schema, kafkaParameters, topic)
    queryExecution.toRdd.foreachPartition { iter =>
      val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
      Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
        finallyBlock = writeTask.close())
    }
  }
} 
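
The per-partition write above leans on Utils.tryWithSafeFinally, which always runs the finally block and, if both blocks throw, attaches the second failure as a suppressed exception instead of masking the original. A minimal sketch with a hypothetical writer helper (Utils is private[spark], hence the package):

package org.apache.spark.examples.io  // hypothetical package

import java.io.{File, PrintWriter}

import org.apache.spark.util.Utils

object SafeFinallySketch {
  def writeLines(file: File, lines: Seq[String]): Unit = {
    val writer = new PrintWriter(file)
    // close() always runs; if both the write and close() fail, the close()
    // failure is suppressed rather than hiding the original exception.
    Utils.tryWithSafeFinally {
      lines.foreach(writer.println)
    } {
      writer.close()
    }
  }
}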
Example 197
Source File: QueryPartitionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import java.io.File
import java.sql.Timestamp

import com.google.common.io.Files
import org.apache.hadoop.fs.FileSystem

import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils

class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  import spark.implicits._

  test("SPARK-5068: query data when path doesn't exist") {
    withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) {
      val testData = sparkContext.parallelize(
        (1 to 10).map(i => TestData(i, i.toString))).toDF()
      testData.createOrReplaceTempView("testData")

      val tmpDir = Files.createTempDir()
      // create the table for test
      sql(s"CREATE TABLE table_with_partition(key int,value string) " +
        s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='1') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='2') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='3') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='4') " +
        "SELECT key,value FROM testData")

      // test for the exist path
      checkAnswer(sql("select key,value from table_with_partition"),
        testData.toDF.collect ++ testData.toDF.collect
          ++ testData.toDF.collect ++ testData.toDF.collect)

      // delete the path of one partition
      tmpDir.listFiles
        .find { f => f.isDirectory && f.getName().startsWith("ds=") }
        .foreach { f => Utils.deleteRecursively(f) }

      // test for after delete the path
      checkAnswer(sql("select key,value from table_with_partition"),
        testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect)

      sql("DROP TABLE IF EXISTS table_with_partition")
      sql("DROP TABLE IF EXISTS createAndInsertTest")
    }
  }

  test("SPARK-21739: Cast expression should initialize timezoneId") {
    withTable("table_with_timestamp_partition") {
      sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
      sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
        "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")

      // test for Cast expression in TableReader
      checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
        Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))

      // test for Cast expression in HiveTableScanExec
      checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
        "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
    }
  }
} 
Example 198
Source File: HiveClientBuilder.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import java.io.File

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.util.VersionInfo

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

private[client] object HiveClientBuilder {
  // In order to speed up test execution during development or in Jenkins, you can specify the path
  // of an existing Ivy cache:
  private val ivyPath: Option[String] = {
    sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse(
      Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath))
  }

  private def buildConf(extraConf: Map[String, String]) = {
    lazy val warehousePath = Utils.createTempDir()
    lazy val metastorePath = Utils.createTempDir()
    metastorePath.delete()
    extraConf ++ Map(
      "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true",
      "hive.metastore.warehouse.dir" -> warehousePath.toString)
  }

  // for testing only
  def buildClient(
      version: String,
      hadoopConf: Configuration,
      extraConf: Map[String, String] = Map.empty): HiveClient = {
    IsolatedClientLoader.forVersion(
      hiveMetastoreVersion = version,
      hadoopVersion = VersionInfo.getVersion,
      sparkConf = new SparkConf(),
      hadoopConf = hadoopConf,
      config = buildConf(extraConf),
      ivyPath = ivyPath).createClient()
  }
} 
Example 199
Source File: SparkSQLEnv.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.io.PrintStream

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
import org.apache.spark.util.Utils


  def stop() {
    logDebug("Shutting down Spark SQL Environment")
    // Stop the SparkContext
    if (SparkSQLEnv.sparkContext != null) {
      sparkContext.stop()
      sparkContext = null
      sqlContext = null
    }
  }
} 
Example 200
Source File: HiveMetastoreLazyInitializationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.util.Utils

class HiveMetastoreLazyInitializationSuite extends SparkFunSuite {

  test("lazily initialize Hive client") {
    val spark = SparkSession.builder()
      .appName("HiveMetastoreLazyInitializationSuite")
      .master("local[2]")
      .enableHiveSupport()
      .config("spark.hadoop.hive.metastore.uris", "thrift://127.0.0.1:11111")
      .getOrCreate()
    val originalLevel = org.apache.log4j.Logger.getRootLogger().getLevel
    try {
      // Avoid outputting a lot of expected warning logs
      spark.sparkContext.setLogLevel("error")

      // We should be able to run Spark jobs without Hive client.
      assert(spark.sparkContext.range(0, 1).count() === 1)

      // Make sure that we are not using the local derby metastore.
      val exceptionString = Utils.exceptionString(intercept[AnalysisException] {
        spark.sql("show tables")
      })
      for (msg <- Seq(
        "show tables",
        "Could not connect to meta store",
        "org.apache.thrift.transport.TTransportException",
        "Connection refused")) {
        assert(exceptionString.contains(msg))
      }
    } finally {
      spark.sparkContext.setLogLevel(originalLevel.toString)
      spark.stop()
    }
  }
}
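
Utils.exceptionString renders a throwable, including its stack trace and causes, into a single string, which is why the test above can match several substrings against one value. A small sketch under the usual assumptions (hypothetical package and object; Utils is private[spark]):

package org.apache.spark.examples.errors  // hypothetical package

import org.apache.spark.util.Utils

object ExceptionStringSketch {
  def main(args: Array[String]): Unit = {
    val failure = new RuntimeException("outer", new IllegalStateException("root cause"))
    val rendered = Utils.exceptionString(failure)
    // The rendered text contains the messages and stack frames of both exceptions.
    assert(rendered.contains("outer"))
    assert(rendered.contains("root cause"))
    println(rendered.split("\n").take(3).mkString("\n"))
  }
}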