org.apache.spark.util.TaskCompletionListener Scala Examples

The following examples show how to use org.apache.spark.util.TaskCompletionListener. Each example is taken from an open-source project; the source file and the project it comes from are noted above each listing.
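Before the project examples, here is a minimal orientation sketch (not taken from any of the projects below): it registers a listener from inside a task so a per-task callback runs when the task finishes, whether it succeeded or failed.

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.util.TaskCompletionListener

object TaskCompletionListenerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "task-completion-listener-example")
    val sums = sc.parallelize(1 to 10, 2).mapPartitions { iter =>
      // Register a callback that fires once this task completes (success or failure).
      TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
        override def onTaskCompletion(context: TaskContext): Unit =
          println(s"partition ${context.partitionId()} finished")
      })
      Iterator(iter.sum)
    }.collect()
    println(sums.mkString(", "))
    sc.stop()
  }
}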
Example 1
Source File: RiakWriterTaskCompletionListener.scala    From spark-riak-connector    with Apache License 2.0
package org.apache.spark.riak

import org.apache.spark.TaskContext
import org.apache.spark.executor.{DataWriteMethod, OutputMetrics}
import org.apache.spark.util.TaskCompletionListener

// Publishes the number of records written by the task as Hadoop output metrics
// once the task completes.
class RiakWriterTaskCompletionListener(recordsWritten: Long) extends TaskCompletionListener {

  override def onTaskCompletion(context: TaskContext): Unit = {
    val metrics = OutputMetrics(DataWriteMethod.Hadoop)
    metrics.setRecordsWritten(recordsWritten)
    context.taskMetrics().outputMetrics = Some(metrics)
  }

}

object RiakWriterTaskCompletionListener {
  def apply(recordsWritten: Long) = new RiakWriterTaskCompletionListener(recordsWritten)
} 
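A hypothetical usage sketch (not part of the connector source): the listener would be registered on the current TaskContext from code running inside a write task, with recordCount standing in for however many records the task actually wrote.

import org.apache.spark.TaskContext
import org.apache.spark.riak.RiakWriterTaskCompletionListener

// Called from code running inside a task; `recordCount` is an assumed local value.
def reportRecordsWritten(recordCount: Long): Unit = {
  TaskContext.get().addTaskCompletionListener(RiakWriterTaskCompletionListener(recordCount))
}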
Example 2
Source File: TaskContextImpl.scala    From SparkCore    with Apache License 2.0
package org.apache.spark

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

import scala.collection.mutable.ArrayBuffer

private[spark] class TaskContextImpl(
    val stageId: Int,
    val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    val runningLocally: Boolean = false,
    val taskMetrics: TaskMetrics = TaskMetrics.empty)
  extends TaskContext
  with Logging {

  // For backwards-compatibility; this method is now deprecated as of 1.3.0.
  override def attemptId(): Long = taskAttemptId

  // List of callback functions to execute when the task completes.
  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

  // Whether the corresponding task has been killed.
  @volatile private var interrupted: Boolean = false

  // Whether the task has completed.
  @volatile private var completed: Boolean = false

  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
    onCompleteCallbacks += listener
    this
  }

  override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f(context)
    }
    this
  }

  @deprecated("use addTaskCompletionListener", "1.1.0")
  override def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f()
    }
  }

  /** Marks the task as completed and triggers the registered listeners. */
  private[spark] def markTaskCompleted(): Unit = {
    completed = true
    val errorMsgs = new ArrayBuffer[String](2)
    // Process completion callbacks in the reverse order of registration
    onCompleteCallbacks.reverse.foreach { listener =>
      try {
        listener.onTaskCompletion(this)
      } catch {
        case e: Throwable =>
          errorMsgs += e.getMessage
          logError("Error in TaskCompletionListener", e)
      }
    }
    if (errorMsgs.nonEmpty) {
      throw new TaskCompletionListenerException(errorMsgs)
    }
  }

  // Marks the task for interruption, i.e. cancellation.
  private[spark] def markInterrupted(): Unit = {
    interrupted = true
  }

  override def isCompleted(): Boolean = completed

  override def isRunningLocally(): Boolean = runningLocally

  override def isInterrupted(): Boolean = interrupted
} 
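To see the mechanics above in action, here is a small sketch. It assumes this Spark 1.x TaskContextImpl (including the markTaskCompleted method shown above) is on the classpath and that the code sits in the org.apache.spark package, since the constructor and markTaskCompleted are private[spark]; it is illustrative, not part of the project source.

package org.apache.spark

import org.apache.spark.util.TaskCompletionListener

object TaskContextImplDemo {
  def main(args: Array[String]): Unit = {
    val context = new TaskContextImpl(stageId = 0, partitionId = 0, taskAttemptId = 0L, attemptNumber = 0)
    // Listener form of registration.
    context.addTaskCompletionListener(new TaskCompletionListener {
      override def onTaskCompletion(ctx: TaskContext): Unit =
        println(s"first registered, fires last (partition ${ctx.partitionId()})")
    })
    // Function form of registration.
    context.addTaskCompletionListener((ctx: TaskContext) => println("second registered, fires first"))
    context.markTaskCompleted() // runs the listeners in reverse registration order
  }
}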
Example 3
Source File: TaskContextSuite.scala    From SparkCore    with Apache License 2.0
package org.apache.spark.scheduler

import org.mockito.Mockito._
import org.mockito.Matchers.any

import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}


class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

  test("calls TaskCompletionListener after failure") {
    TaskContextSuite.completed = false
    sc = new SparkContext("local", "test")
    val rdd = new RDD[String](sc, List()) {
      override def getPartitions = Array[Partition](StubPartition(0))
      override def compute(split: Partition, context: TaskContext) = {
        context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
        sys.error("failed")
      }
    }
    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
    val func = (c: TaskContext, i: Iterator[String]) => i.next()
    val task = new ResultTask[String, String](
      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
    intercept[RuntimeException] {
      task.run(0, 0)
    }
    assert(TaskContextSuite.completed === true)
  }

  test("all TaskCompletionListeners should be called even if some fail") {
    val context = new TaskContextImpl(0, 0, 0, 0)
    val listener = mock(classOf[TaskCompletionListener])
    context.addTaskCompletionListener(_ => throw new Exception("blah"))
    context.addTaskCompletionListener(listener)
    context.addTaskCompletionListener(_ => throw new Exception("blah"))

    intercept[TaskCompletionListenerException] {
      context.markTaskCompleted()
    }

    verify(listener, times(1)).onTaskCompletion(any())
  }

  test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
    sc = new SparkContext("local[1,2]", "test")  // use maxRetries = 2 because we test failed tasks
    // Check that attemptIds are 0 for all tasks' initial attempts
    val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      Seq(TaskContext.get().attemptNumber).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0))

    // Test a job with failed tasks
    val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      val attemptId = TaskContext.get().attemptNumber
      if (iter.next() == 1 && attemptId == 0) {
        throw new Exception("First execution of task failed")
      }
      Seq(attemptId).iterator
    }.collect()
    assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
  }

  test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") {
    sc = new SparkContext("local", "test")
    val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter =>
      Seq(TaskContext.get().attemptId).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0, 1, 2, 3))
  }
}

private object TaskContextSuite {
  @volatile var completed = false
}

private case class StubPartition(index: Int) extends Partition 
Example 4
Source File: CarbonTaskCompletionListener.scala    From carbondata    with Apache License 2.0
package org.apache.spark.sql.carbondata.execution.datasources.tasklisteners

import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.RecordReaderIterator
import org.apache.spark.util.TaskCompletionListener

import org.apache.carbondata.common.logging.LogServiceFactory
import org.apache.carbondata.core.memory.UnsafeMemoryManager
import org.apache.carbondata.core.util.{DataTypeUtil, ThreadLocalTaskInfo}
import org.apache.carbondata.hadoop.internal.ObjectArrayWritable


trait CarbonCompactionTaskCompletionListener extends TaskCompletionListener

// The implementations below extend query- and load-specific listener traits; like the
// compaction trait above, these are simple markers that extend TaskCompletionListener.
trait CarbonQueryTaskCompletionListener extends TaskCompletionListener

trait CarbonLoadTaskCompletionListener extends TaskCompletionListener

case class CarbonQueryTaskCompletionListenerImpl(iter: RecordReaderIterator[InternalRow],
    freeMemory: Boolean = false) extends CarbonQueryTaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    if (iter != null) {
      try {
        iter.close()
      } catch {
        case e: Exception =>
          LogServiceFactory.getLogService(this.getClass.getCanonicalName).error(e)
      }
    }
    if (freeMemory) {
      UnsafeMemoryManager.INSTANCE
        .freeMemoryAll(ThreadLocalTaskInfo.getCarbonTaskInfo.getTaskId)
      ThreadLocalTaskInfo.clearCarbonTaskInfo()
    }
    DataTypeUtil.clearFormatter()
  }
}

case class CarbonLoadTaskCompletionListenerImpl(recordWriter: RecordWriter[NullWritable,
  ObjectArrayWritable],
    taskAttemptContext: TaskAttemptContext) extends CarbonLoadTaskCompletionListener {

  override def onTaskCompletion(context: TaskContext): Unit = {
    try {
      recordWriter.close(taskAttemptContext)
    } finally {
      UnsafeMemoryManager.INSTANCE
        .freeMemoryAll(ThreadLocalTaskInfo.getCarbonTaskInfo.getTaskId)
      ThreadLocalTaskInfo.clearCarbonTaskInfo()
      DataTypeUtil.clearFormatter()
    }
  }
} 
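A hypothetical registration sketch (variable names are illustrative, not from the CarbonData source): a query path would attach the query listener so the reader iterator is closed and the task's off-heap memory is freed when the task ends.

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.RecordReaderIterator

// `readerIterator` is an assumed handle to the iterator created for this task's scan.
def attachQueryListener(readerIterator: RecordReaderIterator[InternalRow]): Unit = {
  TaskContext.get().addTaskCompletionListener(
    CarbonQueryTaskCompletionListenerImpl(readerIterator, freeMemory = true))
}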
Example 5
Source File: TaskContextImpl.scala    From iolap    with Apache License 2.0
package org.apache.spark

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

import scala.collection.mutable.ArrayBuffer

private[spark] class TaskContextImpl(
    val stageId: Int,
    val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    override val taskMemoryManager: TaskMemoryManager,
    val runningLocally: Boolean = false,
    val taskMetrics: TaskMetrics = TaskMetrics.empty)
  extends TaskContext
  with Logging {

  // For backwards-compatibility; this method is now deprecated as of 1.3.0.
  override def attemptId(): Long = taskAttemptId

  // List of callback functions to execute when the task completes.
  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

  // Whether the corresponding task has been killed.
  @volatile private var interrupted: Boolean = false

  // Whether the task has completed.
  @volatile private var completed: Boolean = false

  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
    onCompleteCallbacks += listener
    this
  }

  override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f(context)
    }
    this
  }

  @deprecated("use addTaskCompletionListener", "1.1.0")
  override def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f()
    }
  }

  /** Marks the task as completed and triggers the registered listeners. */
  private[spark] def markTaskCompleted(): Unit = {
    completed = true
    val errorMsgs = new ArrayBuffer[String](2)
    // Process completion callbacks in the reverse order of registration
    onCompleteCallbacks.reverse.foreach { listener =>
      try {
        listener.onTaskCompletion(this)
      } catch {
        case e: Throwable =>
          errorMsgs += e.getMessage
          logError("Error in TaskCompletionListener", e)
      }
    }
    if (errorMsgs.nonEmpty) {
      throw new TaskCompletionListenerException(errorMsgs)
    }
  }

  // Marks the task for interruption, i.e. cancellation.
  private[spark] def markInterrupted(): Unit = {
    interrupted = true
  }

  override def isCompleted(): Boolean = completed

  override def isRunningLocally(): Boolean = runningLocally

  override def isInterrupted(): Boolean = interrupted
} 
Example 6
Source File: TaskContextSuite.scala    From iolap    with Apache License 2.0
package org.apache.spark.scheduler

import org.mockito.Mockito._
import org.mockito.Matchers.any

import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}


class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  test("calls TaskCompletionListener after failure") {
    TaskContextSuite.completed = false
    sc = new SparkContext("local", "test")
    val rdd = new RDD[String](sc, List()) {
      override def getPartitions = Array[Partition](StubPartition(0))
      override def compute(split: Partition, context: TaskContext) = {
        context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
        sys.error("failed")
      }
    }
    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
    val func = (c: TaskContext, i: Iterator[String]) => i.next()
    val task = new ResultTask[String, String](
      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
    intercept[RuntimeException] {
      task.run(0, 0)
    }
    assert(TaskContextSuite.completed === true)
  }

  test("all TaskCompletionListeners should be called even if some fail") {
    val context = new TaskContextImpl(0, 0, 0, 0, null)
    val listener = mock(classOf[TaskCompletionListener])
    context.addTaskCompletionListener(_ => throw new Exception("blah"))
    context.addTaskCompletionListener(listener)
    context.addTaskCompletionListener(_ => throw new Exception("blah"))

    intercept[TaskCompletionListenerException] {
      context.markTaskCompleted()
    }

    verify(listener, times(1)).onTaskCompletion(any())
  }

  test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
    sc = new SparkContext("local[1,2]", "test")  // use maxRetries = 2 because we test failed tasks
    // Check that attemptIds are 0 for all tasks' initial attempts
    val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      Seq(TaskContext.get().attemptNumber).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0))

    // Test a job with failed tasks
    val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      val attemptId = TaskContext.get().attemptNumber
      if (iter.next() == 1 && attemptId == 0) {
        throw new Exception("First execution of task failed")
      }
      Seq(attemptId).iterator
    }.collect()
    assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
  }

  test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") {
    sc = new SparkContext("local", "test")
    val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter =>
      Seq(TaskContext.get().attemptId).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0, 1, 2, 3))
  }
}

private object TaskContextSuite {
  @volatile var completed = false
}

private case class StubPartition(index: Int) extends Partition 
Example 7
Source File: TaskContextSuite.scala    From spark1.52    with Apache License 2.0
package org.apache.spark.scheduler

import org.mockito.Mockito._
import org.mockito.Matchers.any

import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import org.apache.spark.metrics.source.JvmSource


class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  test("provide metrics sources") {//提供测量数源
    val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.metrics.conf", filePath)
    sc = new SparkContext("local", "test", conf)
    val rdd = sc.makeRDD(1 to 1)
    val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => {
      tc.getMetricsSources("jvm").count {
        case source: JvmSource => true
        case _ => false
      }
    }).sum
    assert(result > 0)
  }

  test("calls TaskCompletionListener after failure") {//调用taskcompletionlistener失败后
    TaskContextSuite.completed = false
    sc = new SparkContext("local", "test")
    val rdd = new RDD[String](sc, List()) {
      override def getPartitions = Array[Partition](StubPartition(0))
      override def compute(split: Partition, context: TaskContext) = {
        context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
        sys.error("failed")
      }
    }
    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
    val func = (c: TaskContext, i: Iterator[String]) => i.next()
    val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
    val task = new ResultTask[String, String](
      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
    intercept[RuntimeException] {
      task.run(0, 0, null)
    }
    assert(TaskContextSuite.completed === true)
  }
  // All listeners should be called even if some of them fail
  test("all TaskCompletionListeners should be called even if some fail") {
    val context = TaskContext.empty()
    val listener = mock(classOf[TaskCompletionListener])
    context.addTaskCompletionListener(_ => throw new Exception("blah"))
    context.addTaskCompletionListener(listener)
    context.addTaskCompletionListener(_ => throw new Exception("blah"))

    intercept[TaskCompletionListenerException] {
      context.markTaskCompleted()
    }

    verify(listener, times(1)).onTaskCompletion(any())
  }

  test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
    sc = new SparkContext("local[1,2]", "test")  // use maxRetries = 2 because we test failed tasks
    // Check that attemptIds are 0 for all tasks' initial attempts
    val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      Seq(TaskContext.get().attemptNumber).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0))

    // Test a job with failed tasks
    val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      val attemptId = TaskContext.get().attemptNumber
      if (iter.next() == 1 && attemptId == 0) {
        throw new Exception("First execution of task failed")
      }
      Seq(attemptId).iterator
    }.collect()
    assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
  }
  // attemptId returns taskAttemptId for backwards-compatibility
  test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") {
    sc = new SparkContext("local", "test")
    val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter =>
      Seq(TaskContext.get().attemptId).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0, 1, 2, 3))
  }
}

private object TaskContextSuite {
  @volatile var completed = false
}

private case class StubPartition(index: Int) extends Partition 
Example 8
Source File: TaskContextImpl.scala    From BigDatalog    with Apache License 2.0
package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

private[spark] class TaskContextImpl(
    val stageId: Int,
    val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    override val taskMemoryManager: TaskMemoryManager,
    @transient private val metricsSystem: MetricsSystem,
    internalAccumulators: Seq[Accumulator[Long]],
    val runningLocally: Boolean = false,
    val taskMetrics: TaskMetrics = TaskMetrics.empty)
  extends TaskContext
  with Logging {

  // For backwards-compatibility; this method is now deprecated as of 1.3.0.
  override def attemptId(): Long = taskAttemptId

  // List of callback functions to execute when the task completes.
  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

  // Whether the corresponding task has been killed.
  @volatile private var interrupted: Boolean = false

  // Whether the task has completed.
  @volatile private var completed: Boolean = false

  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
    onCompleteCallbacks += listener
    this
  }

  override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f(context)
    }
    this
  }

  @deprecated("use addTaskCompletionListener", "1.1.0")
  override def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f()
    }
  }

  /** Marks the task as completed and triggers the registered listeners. */
  private[spark] def markTaskCompleted(): Unit = {
    completed = true
    val errorMsgs = new ArrayBuffer[String](2)
    // Process completion callbacks in the reverse order of registration
    onCompleteCallbacks.reverse.foreach { listener =>
      try {
        listener.onTaskCompletion(this)
      } catch {
        case e: Throwable =>
          errorMsgs += e.getMessage
          logError("Error in TaskCompletionListener", e)
      }
    }
    if (errorMsgs.nonEmpty) {
      throw new TaskCompletionListenerException(errorMsgs)
    }
  }

  // Marks the task for interruption, i.e. cancellation.
  private[spark] def markInterrupted(): Unit = {
    interrupted = true
  }

  override def isCompleted(): Boolean = completed

  override def isRunningLocally(): Boolean = runningLocally

  override def isInterrupted(): Boolean = interrupted

  override def getMetricsSources(sourceName: String): Seq[Source] =
    metricsSystem.getSourcesByName(sourceName)

  @transient private val accumulators = new HashMap[Long, Accumulable[_, _]]

  private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
    accumulators(a.id) = a
  }

  private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
    accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
  }

  private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
    accumulators.mapValues(_.localValue).toMap
  }

  //private[spark]
  override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
    // Explicitly register internal accumulators here because these are
    // not captured in the task closure and are already deserialized
    internalAccumulators.foreach(registerAccumulator)
    internalAccumulators.map { a => (a.name.get, a) }.toMap
  }
} 
Example 9
Source File: TaskContextSuite.scala    From BigDatalog    with Apache License 2.0
package org.apache.spark.scheduler

import org.mockito.Mockito._
import org.mockito.Matchers.any

import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import org.apache.spark.metrics.source.JvmSource


class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  test("provide metrics sources") {
    val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.metrics.conf", filePath)
    sc = new SparkContext("local", "test", conf)
    val rdd = sc.makeRDD(1 to 1)
    val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => {
      tc.getMetricsSources("jvm").count {
        case source: JvmSource => true
        case _ => false
      }
    }).sum
    assert(result > 0)
  }

  test("calls TaskCompletionListener after failure") {
    TaskContextSuite.completed = false
    sc = new SparkContext("local", "test")
    val rdd = new RDD[String](sc, List()) {
      override def getPartitions = Array[Partition](StubPartition(0))
      override def compute(split: Partition, context: TaskContext) = {
        context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
        sys.error("failed")
      }
    }
    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
    val func = (c: TaskContext, i: Iterator[String]) => i.next()
    val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
    val task = new ResultTask[String, String](
      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
    intercept[RuntimeException] {
      task.run(0, 0, null)
    }
    assert(TaskContextSuite.completed === true)
  }

  test("all TaskCompletionListeners should be called even if some fail") {
    val context = TaskContext.empty()
    val listener = mock(classOf[TaskCompletionListener])
    context.addTaskCompletionListener(_ => throw new Exception("blah"))
    context.addTaskCompletionListener(listener)
    context.addTaskCompletionListener(_ => throw new Exception("blah"))

    intercept[TaskCompletionListenerException] {
      context.markTaskCompleted()
    }

    verify(listener, times(1)).onTaskCompletion(any())
  }

  test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
    sc = new SparkContext("local[1,2]", "test")  // use maxRetries = 2 because we test failed tasks
    // Check that attemptIds are 0 for all tasks' initial attempts
    val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      Seq(TaskContext.get().attemptNumber).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0))

    // Test a job with failed tasks
    val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      val attemptId = TaskContext.get().attemptNumber
      if (iter.next() == 1 && attemptId == 0) {
        throw new Exception("First execution of task failed")
      }
      Seq(attemptId).iterator
    }.collect()
    assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
  }

  test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") {
    sc = new SparkContext("local", "test")
    val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter =>
      Seq(TaskContext.get().attemptId).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0, 1, 2, 3))
  }
}

private object TaskContextSuite {
  @volatile var completed = false
}

private case class StubPartition(index: Int) extends Partition