com.holdenkarau.spark.testing.DataFrameSuiteBase Scala Examples
The following examples show how to use com.holdenkarau.spark.testing.DataFrameSuiteBase from Holden Karau's spark-testing-base library. Each example is taken from an open-source project; the source file and license are noted above each listing.
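Before the project-specific suites, here is a minimal, self-contained sketch of how DataFrameSuiteBase is typically mixed into a ScalaTest suite. It is illustrative and not taken from any of the projects listed below (the suite, package, and column names are made up); it shows only that the trait manages a local SparkSession exposed as `spark` and provides DataFrame equality assertions such as assertDataFrameEquals, which several examples below rely on.

package example

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.FunSuite

// Illustrative suite: DataFrameSuiteBase starts a local SparkSession for the
// test run and exposes it as `spark`, plus DataFrame comparison helpers.
class MinimalDataFrameSuite extends FunSuite with DataFrameSuiteBase {

  test("two identical DataFrames are equal") {
    import spark.implicits._

    val expected = Seq((1, "a"), (2, "b")).toDF("id", "value")
    val actual = Seq((1, "a"), (2, "b")).toDF("id", "value")

    // assertDataFrameEquals compares schemas and row contents
    assertDataFrameEquals(expected, actual)
  }
}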
Example 1
Source File: NestedCaseClassesTest.scala From cleanframes with Apache License 2.0 | 8 votes |
package cleanframes

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.sql.functions
import org.scalatest.{FlatSpec, Matchers}

class NestedCaseClassesTest extends FlatSpec with Matchers with DataFrameSuiteBase {

  "Cleaner" should "compile and use a custom transformer for a custom type" in {
    import cleanframes.syntax._ // to use `.clean`
    import spark.implicits._

    // define test data for a dataframe
    val input = Seq(
      // @formatter:off
      ("1", "1", "1", "1", null),
      (null, "2", null, "2", "corrupted"),
      ("corrupted", null, "corrupted", null, "true"),
      ("4", "corrupted", "4", "4", "false"),
      ("5", "5", "5", "corrupted", "false"),
      ("6", "6", "6", "6", "true")
      // @formatter:on
    )
      // give column names that are known to you
      .toDF("col1", "col2", "col3", "col4", "col5")

    // import standard functions for conversions shipped with the library
    import cleanframes.instances.all._

    // !important: you need to give a new structure to allow to access sub elements
    val renamed = input.select(
      functions.struct(
        input.col("col1") as "a_col_1",
        input.col("col2") as "a_col_2"
      ) as "a",
      functions.struct(
        input.col("col3") as "b_col_1",
        input.col("col4") as "b_col_2"
      ) as "b",
      input.col("col5") as "c"
    )

    val result = renamed.clean[AB]
      .as[AB]
      .collect

    result should {
      contain theSameElementsAs Seq(
        // @formatter:off
        AB(A(Some(1), Some(1)), B(Some(1), Some(1.0)), Some(false)),
        AB(A(None, Some(2)), B(None, Some(2.0)), Some(false)),
        AB(A(None, None), B(None, None), Some(true)),
        AB(A(Some(4), None), B(Some(4), Some(4.0)), Some(false)),
        AB(A(Some(5), Some(5)), B(Some(5), None), Some(false)),
        AB(A(Some(6), Some(6)), B(Some(6), Some(6.0)), Some(true))
        // @formatter:on
      )
    }
  }
}

case class A(a_col_1: Option[Int], a_col_2: Option[Float])

case class B(b_col_1: Option[Float], b_col_2: Option[Double])

case class AB(a: A, b: B, c: Option[Boolean])
Example 2
Source File: SparkPFASuiteBase.scala From aardpfark with Apache License 2.0 | 6 votes |
package com.ibm.aardpfark.pfa

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.SparkConf
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.scalactic.Equality
import org.scalatest.FunSuite

abstract class SparkPFASuiteBase extends FunSuite with DataFrameSuiteBase with PFATestUtils {

  val sparkTransformer: Transformer
  val input: Array[String]
  val expectedOutput: Array[String]

  val sparkConf = new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID).
    set("spark.driver.host", "localhost")
  override lazy val spark = SparkSession.builder().config(sparkConf).getOrCreate()
  override val reuseContextIfPossible = true

  // Converts column containing a vector to an array
  def withColumnAsArray(df: DataFrame, colName: String) = {
    val vecToArray = udf { v: Vector => v.toArray }
    df.withColumn(colName, vecToArray(df(colName)))
  }

  def withColumnAsArray(df: DataFrame, first: String, others: String*) = {
    val vecToArray = udf { v: Vector => v.toArray }
    var result = df.withColumn(first, vecToArray(df(first)))
    others.foreach(c => result = result.withColumn(c, vecToArray(df(c))))
    result
  }

  // Converts column containing a vector to a sparse vector represented as a map
  def getColumnAsSparseVectorMap(df: DataFrame, colName: String) = {
    val vecToMap = udf { v: Vector => v.toSparse.indices.map(i => (i.toString, v(i))).toMap }
    df.withColumn(colName, vecToMap(df(colName)))
  }
}

abstract class Result

object ApproxEquality extends ApproxEquality

trait ApproxEquality {

  import org.scalactic.Tolerance._
  import org.scalactic.TripleEquals._

  implicit val seqApproxEq: Equality[Seq[Double]] = new Equality[Seq[Double]] {
    override def areEqual(a: Seq[Double], b: Any): Boolean = {
      b match {
        case d: Seq[Double] => a.zip(d).forall { case (l, r) => l === r +- 0.001 }
        case _ => false
      }
    }
  }

  implicit val vectorApproxEq: Equality[Vector] = new Equality[Vector] {
    override def areEqual(a: Vector, b: Any): Boolean = {
      b match {
        case v: Vector => a.toArray.zip(v.toArray).forall { case (l, r) => l === r +- 0.001 }
        case _ => false
      }
    }
  }
}
Example 3
Source File: ColumnPruningSuite.scala From spark-exasol-connector with Apache License 2.0 | 5 votes |
package com.exasol.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funsuite.AnyFunSuite

class ColumnPruningSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  test("returns only required columns in query") {
    createDummyTable()

    val df = spark.read
      .format("com.exasol.spark")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()
      .select("city")

    assert(df.columns.size === 1)
    assert(df.columns.head === "city")

    val result = df.collect().map(x => x.getString(0)).toSet
    assert(result === Set("Berlin", "Paris", "Lisbon"))
  }
}
Example 4
Source File: LongReadsTestSuite.scala From bdg-sequila with Apache License 2.0 | 5 votes |
package org.biodatageeks.sequila.tests.coverage

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{SequilaSession, SparkSession}
import org.biodatageeks.sequila.utils.{Columns, InternalParams, SequilaRegister}
import org.scalatest.{BeforeAndAfter, FunSuite}

class LongReadsTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val bamPath: String = getClass.getResource("/nanopore_guppy_slice.bam").getPath
  val splitSize = 30000
  val tableNameBAM = "reads"

  before {
    System.setSecurityManager(null)
    spark.sql(s"DROP TABLE IF EXISTS $tableNameBAM")
    spark.sql(s"""
         |CREATE TABLE $tableNameBAM
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)
  }

  test("BAM - Nanopore with guppy basecaller") {
    val session: SparkSession = SequilaSession(spark)
    SequilaRegister.register(session)
    session.sparkContext.setLogLevel("WARN")
    val bdg = session.sql(s"SELECT * FROM ${tableNameBAM}")
    assert(bdg.count() === 150)
  }

  test("BAM - coverage - Nanopore with guppy basecaller") {
    spark.sqlContext.setConf(InternalParams.InputSplitSize, (splitSize * 10).toString)
    val session2: SparkSession = SequilaSession(spark)
    SequilaRegister.register(session2)

    val query =
      s"""SELECT ${Columns.CONTIG}, ${Columns.START}, ${Columns.COVERAGE}
          FROM bdg_coverage('$tableNameBAM','nanopore_guppy_slice','bases')
          order by ${Columns.CONTIG},${Columns.START},${Columns.END}
       """.stripMargin
    val covMultiPartitionDF = session2.sql(query)

    //covMultiPartitionDF.coalesce(1).write.mode("overwrite").option("delimiter", "\t").csv("/Users/aga/workplace/multiPart.csv")
    assert(covMultiPartitionDF.count() == 45620) // total count check 45620<---> 45842

    assert(covMultiPartitionDF.filter(s"${Columns.COVERAGE}== 0").count == 0)

    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5010515")
        .first()
        .getShort(2) == 1) // value check [first element]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5022667")
        .first()
        .getShort(2) == 15) // value check [partition boundary]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5036398")
        .first()
        .getShort(2) == 14) // value check [partition boundary]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5056356")
        .first()
        .getShort(2) == 1) // value check [last element]
  }
}
Example 5
Source File: JoinOrderTestSuite.scala From bdg-sequila with Apache License 2.0 | 5 votes |
package org.biodatageeks.sequila.tests.rangejoins

import java.io.{OutputStreamWriter, PrintWriter}

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.bdgenomics.utils.instrumentation.{Metrics, MetricsListener, RecordedMetrics}
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
import org.scalatest.{BeforeAndAfter, FunSuite}

class JoinOrderTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val schema = StructType(
    Seq(StructField("chr", StringType),
        StructField("start", IntegerType),
        StructField("end", IntegerType)))
  val metricsListener = new MetricsListener(new RecordedMetrics())
  val writer = new PrintWriter(new OutputStreamWriter(System.out))

  before {
    System.setSecurityManager(null)
    spark.experimental.extraStrategies = new IntervalTreeJoinStrategyOptim(spark) :: Nil
    Metrics.initialize(sc)

    val rdd1 = sc
      .textFile(getClass.getResource("/refFlat.txt.bz2").getPath)
      .map(r => r.split('\t'))
      .map(r => Row(r(2).toString, r(4).toInt, r(5).toInt))
    val ref = spark.createDataFrame(rdd1, schema)
    ref.createOrReplaceTempView("ref")

    val rdd2 = sc
      .textFile(getClass.getResource("/snp150Flagged.txt.bz2").getPath)
      .map(r => r.split('\t'))
      .map(r => Row(r(1).toString, r(2).toInt, r(3).toInt))
    val snp = spark.createDataFrame(rdd2, schema)
    snp.createOrReplaceTempView("snp")
  }

  test("Join order - broadcasting snp table") {
    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.useJoinOrder", "true")
    val query =
      s"""
         |SELECT snp.*,ref.* FROM ref JOIN snp
         |ON (ref.chr=snp.chr AND snp.end>=ref.start AND snp.start<=ref.end)
       """.stripMargin
    assert(spark.sql(query).count === 616404L)
  }

  test("Join order - broadcasting ref table") {
    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.useJoinOrder", "true")
    val query =
      s"""
         |SELECT snp.*,ref.* FROM snp JOIN ref
         |ON (ref.chr=snp.chr AND snp.end>=ref.start AND snp.start<=ref.end)
       """.stripMargin
    assert(spark.sql(query).count === 616404L)
  }

  after {
    Metrics.print(writer, Some(metricsListener.metrics.sparkMetrics.stageTimes))
    writer.flush()
    Metrics.stopRecording()
  }
}
Example 6
Source File: FeatureCountsTestSuite.scala From bdg-sequila with Apache License 2.0 | 5 votes |
package org.biodatageeks.sequila.tests.rangejoins

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import htsjdk.samtools.ValidationStringency
import org.apache.hadoop.io.LongWritable
import org.biodatageeks.sequila.apps.FeatureCounts.Region
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
import org.biodatageeks.sequila.utils.{Columns, DataQualityFuncs}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.seqdoop.hadoop_bam.util.SAMHeaderReader
import org.seqdoop.hadoop_bam.{BAMInputFormat, SAMRecordWritable}

class FeatureCountsTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  before {
    System.setSecurityManager(null)
    spark.experimental.extraStrategies = new IntervalTreeJoinStrategyOptim(spark) :: Nil
  }

  test("Feature counts for chr1:20138-20294") {
    val query =
      s"""
         | SELECT count(*),targets.${Columns.CONTIG},targets.${Columns.START},targets.${Columns.END}
         | FROM reads JOIN targets
         |ON (
         |  targets.${Columns.CONTIG}=reads.${Columns.CONTIG}
         |  AND
         |  reads.${Columns.END} >= targets.${Columns.START}
         |  AND
         |  reads.${Columns.START} <= targets.${Columns.END}
         |)
         | GROUP BY targets.${Columns.CONTIG},targets.${Columns.START},targets.${Columns.END}
         | HAVING ${Columns.CONTIG}='1' AND ${Columns.START} = 20138 AND ${Columns.END} = 20294""".stripMargin

    spark.sparkContext.hadoopConfiguration.set(
      SAMHeaderReader.VALIDATION_STRINGENCY_PROPERTY,
      ValidationStringency.SILENT.toString)

    val alignments = spark.sparkContext
      .newAPIHadoopFile[LongWritable, SAMRecordWritable, BAMInputFormat](
        getClass.getResource("/NA12878.slice.bam").getPath)
      .map(_._2.get)
      .map(r => Region(DataQualityFuncs.cleanContig(r.getContig), r.getStart, r.getEnd))

    val reads = spark.sqlContext
      .createDataFrame(alignments)
      .withColumnRenamed("contigName", Columns.CONTIG)
      .withColumnRenamed("start", Columns.START)
      .withColumnRenamed("end", Columns.END)
    reads.createOrReplaceTempView("reads")

    val targets = spark.sqlContext
      .createDataFrame(Array(Region("1", 20138, 20294)))
      .withColumnRenamed("contigName", Columns.CONTIG)
      .withColumnRenamed("start", Columns.START)
      .withColumnRenamed("end", Columns.END)
    targets.createOrReplaceTempView("targets")

    spark.sql(query).explain(false)
    assert(spark.sql(query).first().getLong(0) === 1484L)
  }
}
Example 7
Source File: PileupTestBase.scala From bdg-sequila with Apache License 2.0 | 5 votes |
package org.biodatageeks.sequila.tests.pileup

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, ShortType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

class PileupTestBase
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val sampleId = "NA12878.multichrom.md"
  val samResPath: String = getClass.getResource("/multichrom/mdbam/samtools.pileup").getPath
  val referencePath: String = getClass.getResource("/reference/Homo_sapiens_assembly18_chr1_chrM.small.fasta").getPath
  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
  val cramPath: String = getClass.getResource(s"/multichrom/mdcram/${sampleId}.cram").getPath
  val tableName = "reads_bam"
  val tableNameCRAM = "reads_cram"

  val schema: StructType = StructType(
    List(
      StructField("contig", StringType, nullable = true),
      StructField("position", IntegerType, nullable = true),
      StructField("reference", StringType, nullable = true),
      StructField("coverage", ShortType, nullable = true),
      StructField("pileup", StringType, nullable = true),
      StructField("quality", StringType, nullable = true)
    )
  )

  before {
    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
    spark.conf.set("spark.sql.shuffle.partitions", 1) //FIXME: In order to get orderBy in Samtools tests working - related to exchange partitions stage

    spark.sql(s"DROP TABLE IF EXISTS $tableName")
    spark.sql(
      s"""
         |CREATE TABLE $tableName
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

    spark.sql(s"DROP TABLE IF EXISTS $tableNameCRAM")
    spark.sql(
      s"""
         |CREATE TABLE $tableNameCRAM
         |USING org.biodatageeks.sequila.datasources.BAM.CRAMDataSource
         |OPTIONS(path "$cramPath", refPath "$referencePath" )
         |
      """.stripMargin)

    val mapToString = (map: Map[Byte, Short]) => {
      if (map == null)
        "null"
      else
        map.map({ case (k, v) => k.toChar -> v}).mkString.replace(" -> ", ":")
    }

    val byteToString = ((byte: Byte) => byte.toString)

    spark.udf.register("mapToString", mapToString)
    spark.udf.register("byteToString", byteToString)
  }
}
Example 8
Source File: VCFDataSourceTestSuite.scala From bdg-sequila with Apache License 2.0 | 5 votes |
package org.biodatageeks.sequila.tests.datasources

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.biodatageeks.sequila.utils.Columns
import org.scalatest.{BeforeAndAfter, FunSuite}

class VCFDataSourceTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val vcfPath: String = getClass.getResource("/vcf/test.vcf").getPath
  val tableNameVCF = "variants"

  before {
    spark.sql(s"DROP TABLE IF EXISTS $tableNameVCF")
    spark.sql(s"""
         |CREATE TABLE $tableNameVCF
         |USING org.biodatageeks.sequila.datasources.VCF.VCFDataSource
         |OPTIONS(path "$vcfPath")
         |
      """.stripMargin)
  }

  test("VCF - Row count VCFDataSource") {
    val query = s"SELECT * FROM $tableNameVCF"
    spark
      .sql(query)
      .printSchema()

    assert(
      spark
        .sql(query)
        .first()
        .getString(0) === "20")
    assert(spark.sql(query).count() === 7L)
  }

  after {
    spark.sql(s"DROP TABLE IF EXISTS $tableNameVCF")
  }
}
Example 9
Source File: BEDBaseTestSuite.scala From bdg-sequila with Apache License 2.0 | 5 votes |
package org.biodatageeks.sequila.tests.base

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}

class BEDBaseTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with SharedSparkContext
    with BeforeAndAfter {

  val bedPath: String = getClass.getResource("/bed/test.bed").getPath
  val tableNameBED = "targets"

  val bedSimplePath: String = getClass.getResource("/bed/simple.bed").getPath
  val tableNameSimpleBED = "simple_targets"

  before {
    spark.sql(s"DROP TABLE IF EXISTS $tableNameBED")
    spark.sql(s"""
         |CREATE TABLE $tableNameBED
         |USING org.biodatageeks.sequila.datasources.BED.BEDDataSource
         |OPTIONS(path "$bedPath")
         |
      """.stripMargin)

    spark.sql(s"DROP TABLE IF EXISTS $tableNameSimpleBED")
    spark.sql(s"""
         |CREATE TABLE $tableNameSimpleBED
         |USING org.biodatageeks.sequila.datasources.BED.BEDDataSource
         |OPTIONS(path "$bedSimplePath")
         |
      """.stripMargin)
  }

  def after = {
    spark.sql(s"DROP TABLE IF EXISTS $tableNameBED")
    spark.sql(s"DROP TABLE IF EXISTS $tableNameSimpleBED")
  }
}
Example 10
Source File: TestWithSpark.scala From ZparkIO with MIT License | 5 votes |
package com.leobenkel.zparkiotest

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.SparkConf
import org.scalatest.Suite

trait TestWithSpark extends DataFrameSuiteBase { self: Suite =>

  override protected val reuseContextIfPossible: Boolean = true
  override protected val enableHiveSupport: Boolean = false

  def enableSparkUI: Boolean = {
    false
  }

  final override def conf: SparkConf = {
    if (enableSparkUI) {
      super.conf
        .set("spark.ui.enabled", "true")
        .set("spark.ui.port", "4050")
    } else {
      super.conf
    }
  }
}
Example 11
Source File: ExasolRelationSuite.scala From spark-exasol-connector with Apache License 2.0 | 5 votes |
package com.exasol.spark

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

import com.exasol.spark.util.ExasolConnectionManager
import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.mockito.Mockito._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

class ExasolRelationSuite
    extends AnyFunSuite
    with Matchers
    with MockitoSugar
    with DataFrameSuiteBase {

  test("buildScan returns RDD of empty Row-s when requiredColumns is empty (count pushdown)") {
    val query = "SELECT 1"
    val cntQuery = "SELECT COUNT(*) FROM (SELECT 1) A "
    val cnt = 5L

    val manager = mock[ExasolConnectionManager]
    when(manager.withCountQuery(cntQuery)).thenReturn(cnt)

    val relation = new ExasolRelation(spark.sqlContext, query, Option(new StructType), manager)
    val rdd = relation.buildScan()

    assert(rdd.isInstanceOf[RDD[Row]])
    assert(rdd.partitions.size === 4)
    assert(rdd.count === cnt)
    verify(manager, times(1)).withCountQuery(cntQuery)
  }

  test("unhandledFilters should keep non-pushed filters") {
    val schema: StructType = new StructType()
      .add("a", BooleanType)
      .add("b", StringType)
      .add("c", IntegerType)
    val filters = Array[Filter](
      LessThanOrEqual("c", "3"),
      EqualTo("b", "abc"),
      Not(EqualTo("a", false))
    )
    val nullFilters = Array(EqualNullSafe("b", "xyz"))

    val rel = new ExasolRelation(spark.sqlContext, "", Option(schema), null)

    assert(rel.unhandledFilters(filters) === Array.empty[Filter])
    assert(rel.unhandledFilters(filters ++ nullFilters) === nullFilters)
  }
}
Example 12
Source File: TypesSuite.scala From spark-exasol-connector with Apache License 2.0 | 5 votes |
package com.exasol.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.sql.types._
import org.scalatest.funsuite.AnyFunSuite

class TypesSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  test("converts Exasol types to Spark") {
    createAllTypesTable()

    val df = spark.read
      .format("com.exasol.spark")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_ALL_TYPES_TABLE")
      .load()

    val schemaTest = df.schema
    val schemaExpected = Map(
      "MYID" -> LongType,
      "MYTINYINT" -> ShortType,
      "MYSMALLINT" -> IntegerType,
      "MYBIGINT" -> DecimalType(36, 0),
      "MYDECIMALMAX" -> DecimalType(36, 36),
      "MYDECIMALSYSTEMDEFAULT" -> LongType,
      "MYNUMERIC" -> DecimalType(5, 2),
      "MYDOUBLE" -> DoubleType,
      "MYCHAR" -> StringType,
      "MYNCHAR" -> StringType,
      "MYLONGVARCHAR" -> StringType,
      "MYBOOLEAN" -> BooleanType,
      "MYDATE" -> DateType,
      "MYTIMESTAMP" -> TimestampType,
      "MYGEOMETRY" -> StringType,
      "MYINTERVAL" -> StringType
    )

    val fields = schemaTest.toList
    fields.foreach(field => {
      assert(field.dataType === schemaExpected.get(field.name).get)
    })
  }
}
Example 13
Source File: PredicatePushdownSuite.scala From spark-exasol-connector with Apache License 2.0 | 5 votes |
package com.exasol.spark

import java.sql.Timestamp

import org.apache.spark.sql.functions.col

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funsuite.AnyFunSuite

class PredicatePushdownSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  test("with where clause build from filters: filter") {
    createDummyTable()

    import spark.implicits._

    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()
      .filter($"id" < 3)
      .filter(col("city").like("Ber%"))
      .select("id", "city")

    val result = df.collect().map(x => (x.getLong(0), x.getString(1))).toSet
    assert(result.size === 1)
    assert(result === Set((1, "Berlin")))
  }

  test("with where clause build from filters: createTempView and spark.sql") {
    createDummyTable()

    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()

    df.createOrReplaceTempView("myTable")

    val myDF = spark
      .sql("SELECT id, city FROM myTable WHERE id BETWEEN 1 AND 3 AND name < 'Japan'")

    val result = myDF.collect().map(x => (x.getLong(0), x.getString(1))).toSet
    assert(result.size === 2)
    assert(result === Set((1, "Berlin"), (2, "Paris")))
  }

  test("date and timestamp should be read and filtered correctly") {
    import java.sql.Date

    createDummyTable()
    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT date_info, updated_at FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()

    val minTimestamp = Timestamp.valueOf("2017-12-30 00:00:00.0000")
    val testDate = Date.valueOf("2017-12-31")

    val resultDate = df.collect().map(_.getDate(0))
    assert(resultDate.contains(testDate))

    val resultTimestamp = df.collect().map(_.getTimestamp(1)).map(x => x.after(minTimestamp))
    assert(!resultTimestamp.contains(false))

    val filteredByDateDF = df.filter(col("date_info") === testDate)
    assert(filteredByDateDF.count() === 1)

    val filteredByTimestampDF = df.filter(col("updated_at") < minTimestamp)
    assert(filteredByTimestampDF.count() === 0)
  }

  test("count should be performed successfully") {
    createDummyTable()
    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()
    val result = df.count()
    assert(result === 3)
  }
}
Example 14
Source File: ReservedKeywordsSuite.scala From spark-exasol-connector with Apache License 2.0 | 5 votes |
package com.exasol.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funsuite.AnyFunSuite

class ReservedKeywordsSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  val SCHEMA: String = "RESERVED_KEYWORDS"
  val TABLE: String = "TEST_TABLE"

  test("queries a table with reserved keyword") {
    createTable()

    val expected = Set("True", "False", "Checked")

    val df1 = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"""SELECT "CONDITION" FROM $SCHEMA.$TABLE""")
      .load()

    assert(df1.collect().map(x => x(0)).toSet === expected)

    val df2 = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $SCHEMA.$TABLE")
      .load()
      .select("condition")

    assert(df2.collect().map(x => x(0)).toSet === expected)
  }

  ignore("queries a table with reserved keyword using where clause") {
    createTable()

    val df = spark.read
      .format("com.exasol.spark")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $SCHEMA.$TABLE")
      .load()
      .select(s""""CONDITION"""")
      .where(s""""CONDITION" LIKE '%Check%'""")

    assert(df.collect().map(x => x(0)).toSet === Set("Checked"))
  }

  def createTable(): Unit =
    exaManager.withExecute(
      Seq(
        s"DROP SCHEMA IF EXISTS $SCHEMA CASCADE",
        s"CREATE SCHEMA $SCHEMA",
        s"""|CREATE OR REPLACE TABLE $SCHEMA.$TABLE (
            |  ID INTEGER IDENTITY NOT NULL,
            |  "CONDITION" VARCHAR(100) UTF8
            |)""".stripMargin,
        s"""INSERT INTO $SCHEMA.$TABLE ("CONDITION") VALUES ('True')""",
        s"""INSERT INTO $SCHEMA.$TABLE ("CONDITION") VALUES ('False')""",
        s"""INSERT INTO $SCHEMA.$TABLE ("CONDITION") VALUES ('Checked')""",
        "commit"
      )
    )
}
Example 15
Source File: StreamingKMeansSuite.scala From spark-structured-streaming-ml with Apache License 2.0 | 5 votes |
package com.highperformancespark.examples.structuredstreaming

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.scalatest.FunSuite
import org.apache.log4j.{Level, Logger}

case class TestRow(features: Vector)

class StreamingKMeansSuite extends FunSuite with DataFrameSuiteBase {

  override def beforeAll(): Unit = {
    super.beforeAll()
    Logger.getLogger("org").setLevel(Level.OFF)
  }

  test("streaming model with one center should converge to true center") {
    import spark.implicits._
    val k = 1
    val dim = 5
    val clusterSpread = 0.1
    val seed = 63
    // TODO: this test is very flaky. The centers do not converge for some
    // (most?) random seeds
    val (batches, trueCenters) =
      StreamingKMeansSuite.generateBatches(100, 80, k, dim, clusterSpread, seed)
    val inputStream = MemoryStream[TestRow]
    val ds = inputStream.toDS()
    val skm = new StreamingKMeans().setK(k).setRandomCenters(dim, 0.01)
    val query = skm.evilTrain(ds.toDF())
    val streamingModels = batches.map { batch =>
      inputStream.addData(batch)
      query.processAllAvailable()
      skm.getModel
    }
    // TODO: use spark's testing suite
    streamingModels.last.centers.zip(trueCenters).foreach {
      case (center, trueCenter) =>
        val centers = center.toArray.mkString(",")
        val trueCenters = trueCenter.toArray.mkString(",")
        println(s"${centers} | ${trueCenters}")
        assert(center.toArray.zip(trueCenter.toArray).forall(
          x => math.abs(x._1 - x._2) < 0.1))
    }
    query.stop()
  }

  def compareBatchAndStreaming(
      batchModel: KMeansModel,
      streamingModel: StreamingKMeansModel,
      validationData: DataFrame): Unit = {
    assert(batchModel.clusterCenters === streamingModel.centers)
    // TODO: implement prediction comparison
  }
}

object StreamingKMeansSuite {

  def generateBatches(
      numPoints: Int,
      numBatches: Int,
      k: Int,
      d: Int,
      r: Double,
      seed: Int,
      initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[TestRow]], Array[Vector]) = {
    val rand = scala.util.Random
    rand.setSeed(seed)
    val centers = initCenters match {
      case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
      case _ => initCenters
    }
    val data = (0 until numBatches).map { i =>
      (0 until numPoints).map { idx =>
        val center = centers(idx % k)
        val vec = Vectors.dense(
          Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
        TestRow(vec)
      }
    }
    (data, centers)
  }
}
Example 16
Source File: CustomSinkSuite.scala From spark-structured-streaming-ml with Apache License 2.0 | 5 votes |
package com.highperformancespark.examples.structuredstreaming

import com.holdenkarau.spark.testing.DataFrameSuiteBase

import scala.collection.mutable.ListBuffer

import org.scalatest.FunSuite
import org.apache.spark._
import org.apache.spark.sql.{Dataset, DataFrame, Encoder, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

class CustomSinkSuite extends FunSuite with DataFrameSuiteBase {

  test("really simple test of the custom sink") {
    import spark.implicits._
    val input = MemoryStream[String]
    val doubled = input.toDS().map(x => x + " " + x)
    val formatName = ("com.highperformancespark.examples" +
      "structuredstreaming.CustomSinkCollectorProvider")
    val query = doubled.writeStream
      .queryName("testCustomSinkBasic")
      .format(formatName)
      .start()
    val inputData = List("hi", "holden", "bye", "pandas")
    input.addData(inputData)
    assert(query.isActive === true)
    query.processAllAvailable()
    assert(query.exception === None)
    assert(Pandas.results(0) === inputData.map(x => x + " " + x))
  }
}

object Pandas{
  val results = new ListBuffer[Seq[String]]()
}

class CustomSinkCollectorProvider extends ForeachDatasetSinkProvider {
  override def func(df: DataFrame) {
    val spark = df.sparkSession
    import spark.implicits._
    Pandas.results += df.as[String].rdd.collect()
  }
}
Example 17
Source File: EncryptedReadSuite.scala From spark-excel with Apache License 2.0 | 5 votes |
package com.crealytics.spark.excel

import org.apache.spark.sql._
import org.apache.spark.sql.types._

import scala.collection.JavaConverters._

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers

object EncryptedReadSuite {
  val simpleSchema = StructType(
    List(
      StructField("A", DoubleType, true),
      StructField("B", DoubleType, true),
      StructField("C", DoubleType, true),
      StructField("D", DoubleType, true)
    )
  )

  val expectedData = List(Row(1.0d, 2.0d, 3.0d, 4.0d)).asJava
}

class EncryptedReadSuite extends AnyFunSpec with DataFrameSuiteBase with Matchers {
  import EncryptedReadSuite._

  lazy val expected = spark.createDataFrame(expectedData, simpleSchema)

  def readFromResources(path: String, password: String, maxRowsInMemory: Option[Int] = None): DataFrame = {
    val url = getClass.getResource(path)
    val reader = spark.read
      .excel(
        dataAddress = s"Sheet1!A1",
        treatEmptyValuesAsNulls = true,
        workbookPassword = password,
        inferSchema = true
      )
    val withMaxRows = maxRowsInMemory.fold(reader)(rows => reader.option("maxRowsInMemory", s"$rows"))
    withMaxRows.load(url.getPath)
  }

  describe("spark-excel") {
    it("should read encrypted xslx file") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xlsx", "fooba")
      assertDataFrameEquals(expected, df)
    }

    it("should read encrypted xlsx file with maxRowsInMem=10") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xlsx", "fooba", maxRowsInMemory = Some(10))
      assertDataFrameEquals(expected, df)
    }

    it("should read encrypted xlsx file with maxRowsInMem=1") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xlsx", "fooba", maxRowsInMemory = Some(1))
      assertDataFrameEquals(expected, df)
    }

    it("should read encrypted xls file") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xls", "fooba")
      assertDataFrameEquals(expected, df)
    }
  }
}
Example 18
Source File: ProcessTest.scala From incubator-s2graph with Apache License 2.0 | 5 votes |
package org.apache.s2graph.s2jobs.task

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.FunSuite

class ProcessTest extends FunSuite with DataFrameSuiteBase {

  test("SqlProcess execute sql") {
    import spark.implicits._
    val inputDF = Seq(
      ("a", "b", "friend"),
      ("a", "c", "friend"),
      ("a", "d", "friend")
    ).toDF("from", "to", "label")
    val inputMap = Map("input" -> inputDF)

    val sql = "SELECT * FROM input WHERE to = 'b'"
    val conf = TaskConf("test", "sql", Seq("input"), Map("sql" -> sql))
    val process = new SqlProcess(conf)

    val rstDF = process.execute(spark, inputMap)
    val tos = rstDF.collect().map{ row => row.getAs[String]("to")}

    assert(tos.size == 1)
    assert(tos.head == "b")
  }
}
Example 19
Source File: WalLogAggregateProcessTest.scala From incubator-s2graph with Apache License 2.0 | 5 votes |
package org.apache.s2graph.s2jobs.wal.process

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

class WalLogAggregateProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {

  import org.apache.s2graph.s2jobs.wal.TestData._

  test("test entire process") {
    import spark.sqlContext.implicits._

    val edges = spark.createDataset(walLogsLs).toDF()
    val processKey = "agg"
    val inputMap = Map(processKey -> edges)

    val taskConf = new TaskConf(name = "test", `type` = "agg", inputs = Seq(processKey),
      options = Map("maxNumOfEdges" -> "10")
    )

    val job = new WalLogAggregateProcess(taskConf = taskConf)
    val processed = job.execute(spark, inputMap)

    processed.printSchema()
    processed.orderBy("from").as[WalLogAgg].collect().zip(aggExpected).foreach { case (real, expected) =>
      real shouldBe expected
    }
  }
}
Example 20
Source File: BuildTopFeaturesProcessTest.scala From incubator-s2graph with Apache License 2.0 | 5 votes |
package org.apache.s2graph.s2jobs.wal.process

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal.DimValCountRank
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

class BuildTopFeaturesProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {

  import org.apache.s2graph.s2jobs.wal.TestData._

  test("test entire process.") {
    import spark.implicits._

    val df = spark.createDataset(aggExpected).toDF()

    val taskConf = new TaskConf(name = "test", `type` = "test", inputs = Seq("input"),
      options = Map("minUserCount" -> "0")
    )
    val job = new BuildTopFeaturesProcess(taskConf)

    val inputMap = Map("input" -> df)

    val featureDicts = job.execute(spark, inputMap)
      .orderBy("dim", "rank")
      .map(DimValCountRank.fromRow)
      .collect()

    featureDicts shouldBe featureDictExpected
  }
}
Example 21
Source File: FilterTopFeaturesProcessTest.scala From incubator-s2graph with Apache License 2.0 | 5 votes |
package org.apache.s2graph.s2jobs.wal.process

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal.transformer.DefaultTransformer
import org.apache.s2graph.s2jobs.wal.{DimValCountRank, WalLogAgg}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

class FilterTopFeaturesProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {

  import org.apache.s2graph.s2jobs.wal.TestData._

  test("test filterTopKsPerDim.") {
    import spark.implicits._

    val featureDf = spark.createDataset(featureDictExpected).map { x =>
      (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
    }.toDF("dim", "value", "count", "rank")

    val maxRankPerDim = spark.sparkContext.broadcast(Map.empty[String, Int])

    // filter nothing because all feature has rank < 10
    val filtered = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 10)
    val real = filtered.orderBy("dim", "rank").map(DimValCountRank.fromRow).collect()
    real.zip(featureDictExpected).foreach { case (real, expected) =>
      real shouldBe expected
    }

    // filter rank >= 2
    val filtered2 = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 2)
    val real2 = filtered2.orderBy("dim", "rank").map(DimValCountRank.fromRow).collect()
    real2 shouldBe featureDictExpected.filter(_.rank < 2)
  }

  test("test filterWalLogAgg.") {
    import spark.implicits._

    val walLogAgg = spark.createDataset(aggExpected)
    val featureDf = spark.createDataset(featureDictExpected).map { x =>
      (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
    }.toDF("dim", "value", "count", "rank")
    val maxRankPerDim = spark.sparkContext.broadcast(Map.empty[String, Int])

    val transformers = Seq(DefaultTransformer(TaskConf.Empty))

    // filter nothing. so input, output should be same.
    val featureFiltered = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 10)
    val validFeatureHashKeys = FilterTopFeaturesProcess.collectDistinctFeatureHashes(spark, featureFiltered)
    val validFeatureHashKeysBCast = spark.sparkContext.broadcast(validFeatureHashKeys)
    val real = FilterTopFeaturesProcess.filterWalLogAgg(spark, walLogAgg, transformers, validFeatureHashKeysBCast)
      .collect().sortBy(_.from)

    real.zip(aggExpected).foreach { case (real, expected) =>
      real shouldBe expected
    }
  }

  test("test entire process. filter nothing.") {
    import spark.implicits._

    val df = spark.createDataset(aggExpected).toDF()
    val featureDf = spark.createDataset(featureDictExpected).map { x =>
      (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
    }.toDF("dim", "value", "count", "rank")

    val inputKey = "input"
    val featureDictKey = "feature"

    // filter nothing since we did not specified maxRankPerDim and defaultMaxRank.
    val taskConf = new TaskConf(name = "test", `type` = "test",
      inputs = Seq(inputKey, featureDictKey),
      options = Map(
        "featureDict" -> featureDictKey,
        "walLogAgg" -> inputKey
      )
    )

    val inputMap = Map(inputKey -> df, featureDictKey -> featureDf)

    val job = new FilterTopFeaturesProcess(taskConf)

    val filtered = job.execute(spark, inputMap)
      .orderBy("from")
      .as[WalLogAgg]
      .collect()

    filtered.zip(aggExpected).foreach { case (real, expected) =>
      real shouldBe expected
    }
  }
}
Example 22
Source File: WordCountTest.scala From robin-sparkles with Apache License 2.0 | 5 votes |
package com.highperformancespark.robinsparkles

//import com.highperformancespark.robinsparkles.listener._
import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.FunSuite

class WordCountTest extends FunSuite with DataFrameSuiteBase {

  test("word count with Stop Words Removed"){
    // TODO: Add listener
    val linesRDD = sc.parallelize(Seq(
      "How happy was the panda? You ask.",
      "Panda is the most happy panda in all the#!?ing land!"))

    val stopWords: Set[String] = Set("a", "the", "in", "was", "there", "she", "he")
    val splitTokens: Array[Char] = "#%?!. ".toCharArray

    val wordCounts = WordCount.withStopWordsFiltered(
      linesRDD, splitTokens, stopWords)
    val wordCountsAsMap = wordCounts.collectAsMap()

    assert(!wordCountsAsMap.contains("the"))
    assert(!wordCountsAsMap.contains("?"))
    assert(!wordCountsAsMap.contains("#!?ing"))
    assert(wordCountsAsMap.contains("ing"))
    assert(wordCountsAsMap.get("panda").get.equals(3))
  }
}
Example 23
Source File: OptionalPrimitivesTest.scala From cleanframes with Apache License 2.0 | 5 votes |
package cleanframes

import org.scalatest.{FlatSpec, Matchers}
import com.holdenkarau.spark.testing.DataFrameSuiteBase

class OptionalPrimitivesTest extends FlatSpec with Matchers with DataFrameSuiteBase {

  "Cleaner" should "transform data to concrete types if possible" in {
    import spark.implicits._ // to use `.toDF` and `.as`
    import cleanframes.syntax._ // to use `.clean`

    // define test data for a dataframe
    val input = Seq(
      // @formatter:off
      ("1", "1", "1", "1", "1", "1", "true"),
      ("corrupted", "2", "2", "2", "2", "2", "false"),
      ("3", "corrupted", "3", "3", "3", "3", null),
      ("4", "4", "corrupted", "4", "4", "4", "true"),
      ("5", "5", "5", "corrupted", "5", "5", "false"),
      ("6", "6", "6", "6", "corrupted", "6", null),
      ("7", "7", "7", "7", "7", "corrupted", "true"),
      ("8", "8", "8", "8", "8", "8", "corrupted")
      // @formatter:on
    )
      // important! dataframe's column names must match parameter names of the case class passed to `.clean` method
      .toDF("col1", "col2", "col3", "col4", "col5", "col6", "col7")

    // import standard functions for conversions shipped with the library
    import cleanframes.instances.all._

    val result = input
      // call cleanframes API
      .clean[AnyValsExample]
      // make Dataset
      .as[AnyValsExample]
      .collect

    import cleanframes.{AnyValsExample => Model} // just for readability sake

    result should {
      contain theSameElementsAs Seq(
        // @formatter:off
        Model(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true)),
        Model(None, Some(2), Some(2), Some(2), Some(2), Some(2), Some(false)),
        Model(Some(3), None, Some(3), Some(3), Some(3), Some(3), Some(false)),
        Model(Some(4), Some(4), None, Some(4), Some(4), Some(4), Some(true)),
        Model(Some(5), Some(5), Some(5), None, Some(5), Some(5), Some(false)),
        Model(Some(6), Some(6), Some(6), Some(6), None, Some(6), Some(false)),
        Model(Some(7), Some(7), Some(7), Some(7), Some(7), None, Some(true)),
        Model(Some(8), Some(8), Some(8), Some(8), Some(8), Some(8), Some(false))
        // @formatter:on
      )
    }.and(have size 8)
  }
}

case class AnyValsExample(col1: Option[Int], col2: Option[Byte], col3: Option[Short], col4: Option[Long], col5: Option[Float], col6: Option[Double], col7: Option[Boolean])
Example 24
Source File: SingleImportInsteadAllTest.scala From cleanframes with Apache License 2.0 | 5 votes |
package cleanframes

import org.scalatest.{FlatSpec, Matchers}
import com.holdenkarau.spark.testing.DataFrameSuiteBase

class SingleImportInsteadAllTest extends FlatSpec with Matchers with DataFrameSuiteBase {

  "Cleaner" should "transform data by using concrete import" in {
    import spark.implicits._ // to use `.toDF` and `.as`
    import cleanframes.syntax._ // to use `.clean`

    // define test data for a dataframe
    val input = Seq(
      ("1"),
      ("corrupted"),
      ("3"),
      ("4"),
      ("5"),
      (null),
      ("null"),
      (" x "),
      (" 6 2 "),
      ("6"),
      ("7"),
      ("8")
    )
      // important! dataframe's column names must match parameter names of the case class passed to `.clean` method
      .toDF("col1")

    // import standard functions for conversions shipped with the library
    import cleanframes.instances.int._

    val result = input
      // call cleanframes API
      .clean[SingleIntModel]
      // make Dataset
      .as[SingleIntModel]
      .collect

    result should {
      contain theSameElementsAs Seq(
        SingleIntModel(Some(1)),
        SingleIntModel(None),
        SingleIntModel(Some(3)),
        SingleIntModel(Some(4)),
        SingleIntModel(Some(5)),
        SingleIntModel(None),
        SingleIntModel(None),
        SingleIntModel(None),
        SingleIntModel(None),
        SingleIntModel(Some(6)),
        SingleIntModel(Some(7)),
        SingleIntModel(Some(8))
      )
    }
  }
}

case class SingleIntModel(col1: Option[Int])
Example 25
Source File: BigQueryClientSpecs.scala From spark-bigquery with Apache License 2.0 | 4 votes |
package com.samelamin.spark.bigquery

import java.io.File

import com.google.api.services.bigquery.Bigquery
import com.google.api.services.bigquery.model._
import com.google.cloud.hadoop.io.bigquery._
import com.holdenkarau.spark.testing.DataFrameSuiteBase
import com.samelamin.spark.bigquery.converters.{BigQueryAdapter, SchemaConverters}
import org.apache.commons.io.FileUtils
import org.apache.spark.sql._
import org.mockito.Matchers.{any, eq => mockitoEq}
import org.mockito.Mockito._
import org.scalatest.FeatureSpec
import org.scalatest.mock.MockitoSugar

class BigQueryClientSpecs extends FeatureSpec with DataFrameSuiteBase with MockitoSugar {
  val BQProjectId = "google.com:foo-project"

  def setupBigQueryClient(sqlCtx: SQLContext, bigQueryMock: Bigquery): BigQueryClient = {
    val fakeJobReference = new JobReference()
    fakeJobReference.setProjectId(BQProjectId)
    fakeJobReference.setJobId("bigquery-job-1234")
    val dataProjectId = "publicdata"
    // Create the job result.
    val jobStatus = new JobStatus()
    jobStatus.setState("DONE")
    jobStatus.setErrorResult(null)
    val jobHandle = new Job()
    jobHandle.setStatus(jobStatus)
    jobHandle.setJobReference(fakeJobReference)
    // Create table reference.
    val tableRef = new TableReference()
    tableRef.setProjectId(dataProjectId)
    tableRef.setDatasetId("test_dataset")
    tableRef.setTableId("test_table")
    // Mock getting Bigquery jobs
    when(bigQueryMock.jobs().get(any[String], any[String]).execute())
      .thenReturn(jobHandle)
    when(bigQueryMock.jobs().insert(any[String], any[Job]).execute())
      .thenReturn(jobHandle)
    val bigQueryClient = new BigQueryClient(sqlCtx, bigQueryMock)
    bigQueryClient
  }

  scenario("When writing to BQ") {
    val sqlCtx = sqlContext
    import sqlCtx.implicits._
    val gcsPath = "/tmp/testfile2.json"
    FileUtils.deleteQuietly(new File(gcsPath))
    val adaptedDf = BigQueryAdapter(sc.parallelize(List(1, 2, 3)).toDF)
    val bigQueryMock = mock[Bigquery](RETURNS_DEEP_STUBS)
    val fullyQualifiedOutputTableId = "testProjectID:test_dataset.test"
    val targetTable = BigQueryStrings.parseTableReference(fullyQualifiedOutputTableId)
    val bigQueryClient = setupBigQueryClient(sqlCtx, bigQueryMock)
    val bigQuerySchema = SchemaConverters.SqlToBQSchema(adaptedDf)

    bigQueryClient.load(targetTable,bigQuerySchema,gcsPath)
    verify(bigQueryMock.jobs().insert(mockitoEq(BQProjectId),any[Job]), times(1)).execute()
  }

  scenario("When reading from BQ") {
    val sqlCtx = sqlContext
    val fullyQualifiedOutputTableId = "testProjectID:test_dataset.test"
    val sqlQuery = s"select * from $fullyQualifiedOutputTableId"

    val bqQueryContext = new BigQuerySQLContext(sqlCtx)
    bqQueryContext.setBigQueryProjectId(BQProjectId)
    val bigQueryMock = mock[Bigquery](RETURNS_DEEP_STUBS)
    val bigQueryClient = setupBigQueryClient(sqlCtx, bigQueryMock)

    bigQueryClient.selectQuery(sqlQuery)
    verify(bigQueryMock.jobs().insert(mockitoEq(BQProjectId),any[Job]), times(1)).execute()
  }

  scenario("When running a DML Queries") {
    val sqlCtx = sqlContext
    val fullyQualifiedOutputTableId = "testProjectID:test_dataset.test"
    val dmlQuery = s"UPDATE $fullyQualifiedOutputTableId SET test_col = new_value WHERE test_col = old_value"
    val bqQueryContext = new BigQuerySQLContext(sqlCtx)
    bqQueryContext.setBigQueryProjectId(BQProjectId)
    val bigQueryMock = mock[Bigquery](RETURNS_DEEP_STUBS)
    val bigQueryClient = setupBigQueryClient(sqlCtx, bigQueryMock)

    bigQueryClient.runDMLQuery(dmlQuery)
    verify(bigQueryMock.jobs().insert(mockitoEq(BQProjectId),any[Job]), times(1)).execute()
  }
}