org.apache.spark.sql.expressions.UserDefinedFunction Scala Examples
The following examples show how to use org.apache.spark.sql.expressions.UserDefinedFunction.
Each example notes the project it was taken from and that project's license.
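Before the project-specific examples, here is a minimal, self-contained sketch of the pattern they all share: wrap an ordinary Scala function with org.apache.spark.sql.functions.udf to obtain a UserDefinedFunction, then apply it to columns or register it for SQL. The session setup and column names below are illustrative only, not taken from any of the projects that follow.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, udf}

object MinimalUdfExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udf-example").master("local[*]").getOrCreate()
    import spark.implicits._

    // A UserDefinedFunction wrapping an ordinary Scala function.
    val toUpper: UserDefinedFunction = udf((s: String) => if (s == null) null else s.toUpperCase)

    val df = Seq("alice", "bob").toDF("name")
    df.withColumn("name_upper", toUpper(col("name"))).show()

    // Optionally register the same UDF for use in SQL expressions.
    spark.udf.register("to_upper", toUpper)
    df.createOrReplaceTempView("people")
    spark.sql("SELECT to_upper(name) AS name_upper FROM people").show()

    spark.stop()
  }
}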
Example 1
Source File: CustomUDF.scala From jgit-spark-connector with Apache License 2.0 | 5 votes |
package tech.sourced.engine.udf

import org.apache.spark.groupon.metrics.{NotInitializedException, SparkTimer, UserMetricsSystem}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction

// NOTE: the enclosing CustomUDF declaration is truncated in this listing;
// only its two apply signatures and closing brace survive below.
  def apply(session: SparkSession): UserDefinedFunction

  def apply(): UserDefinedFunction = this.apply(session = null)
}

sealed class SparkTimerUDFWrapper(name: String) extends Logging {
  lazy val timer: SparkTimer = init()

  private def init(): SparkTimer = {
    try {
      UserMetricsSystem.timer(name)
    } catch {
      case _: NotInitializedException => {
        logWarning("SparkMetric not initialized on UDF")
        null
      }
    }
  }

  def time[T](f: => T): T =
    if (timer == null) {
      f
    } else {
      timer.time(f)
    }
}
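A minimal sketch of how a concrete UDF might plug into this pattern. The ConcatUDF object below is hypothetical, not part of jgit-spark-connector; it only illustrates the apply(session) contract and the timer wrapper, with the name field following the other CustomUDF implementations shown later in this listing.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf

// Hypothetical concrete implementation following the CustomUDF contract shown above.
case object ConcatUDF extends CustomUDF {
  override val name = "concatStrings"
  private val timer = new SparkTimerUDFWrapper(name)

  override def apply(session: SparkSession): UserDefinedFunction =
    udf[String, String, String]((a, b) => timer.time { a + b })
}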
Example 2
Source File: functionsSuite.scala From spark-corenlp with GNU General Public License v3.0 | 5 votes |
package com.databricks.spark.corenlp

import scala.reflect.runtime.universe.TypeTag

import com.databricks.spark.corenlp.functions._

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._

class functionsSuite extends SparkFunSuite {

  private val sentence1 = "Stanford University is located in California."
  private val sentence2 = "It is a great university."
  private val document = s"$sentence1 $sentence2"
  private val xml = s"<xml><p>$sentence1</p><p>$sentence2</p></xml>"

  private def testFunction[T: TypeTag](function: UserDefinedFunction, input: T, expected: Any): Unit = {
    val df = sqlContext.createDataFrame(Seq((0, input))).toDF("id", "input")
    val actual = df.select(function(col("input"))).first().get(0)
    assert(actual === expected)
  }

  test("ssplit") {
    testFunction(ssplit, document, Seq(sentence1, sentence2))
  }

  test("tokenize") {
    val expected = Seq("Stanford", "University", "is", "located", "in", "California", ".")
    testFunction(tokenize, sentence1, expected)
  }

  test("pos") {
    val expected = Seq("NNP", "NNP", "VBZ", "JJ", "IN", "NNP", ".")
    testFunction(pos, sentence1, expected)
  }

  test("lemma") {
    val expected = Seq("Stanford", "University", "be", "located", "in", "California", ".")
    testFunction(lemma, sentence1, expected)
  }

  test("ner") {
    val expected = Seq("ORGANIZATION", "ORGANIZATION", "O", "O", "O", "STATE_OR_PROVINCE", "O")
    testFunction(ner, sentence1, expected)
  }

  test("natlog") {
    val expected = Seq("up", "up", "up", "up", "up", "up", "up")
    testFunction(natlog, sentence1, expected)
  }

  test("cleanxml") {
    val expected = "Stanford University is located in California . It is a great university ."
    testFunction(cleanxml, xml, expected)
  }

  test("coref") {
    val expected = Seq(
      Row("Stanford University", Seq(
        Row(1, 1, "Stanford University"),
        Row(2, 1, "It"))))
    testFunction(coref, document, expected)
  }

  test("depparse") {
    val expected = Seq(
      Row("University", 2, "compound", "Stanford", 1, 1.0),
      Row("located", 4, "nsubjpass", "University", 2, 1.0),
      Row("located", 4, "auxpass", "is", 3, 1.0),
      Row("California", 6, "case", "in", 5, 1.0),
      Row("located", 4, "nmod:in", "California", 6, 1.0),
      Row("located", 4, "punct", ".", 7, 1.0))
    testFunction(depparse, sentence1, expected)
  }

  test("openie") {
    val expected = Seq(
      Row("Stanford University", "is", "located", 1.0),
      Row("Stanford University", "is located in", "California", 1.0))
    testFunction(openie, sentence1, expected)
  }

  test("sentiment") {
    testFunction(sentiment, sentence1, 1)
    testFunction(sentiment, sentence2, 4)
    testFunction(sentiment, document, 1) // only look at the first sentence
  }
}
Example 3
Source File: UDFBuilder.scala From sope with Apache License 2.0 | 5 votes |
package com.sope.etl.register

import java.io.File
import java.net.URLClassLoader

import com.sope.etl.getObjectInstance
import com.sope.etl.transform.exception.YamlDataTransformException
import com.sope.etl.utils.JarUtils
import com.sope.utils.Logging
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.expressions.UserDefinedFunction

import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.IMain

object UDFBuilder extends Logging {

  val DefaultClassLocation = "/tmp/sope/dynamic/"
  val DefaultJarLocation = "/tmp/sope/sope-dynamic-udf.jar"

  // evalUDF, which compiles the provided code strings into UserDefinedFunctions,
  // is not shown in this excerpt.
  def buildDynamicUDFs(udfCodeMap: Map[String, String]): Map[String, UserDefinedFunction] = {
    val file = new java.io.File(UDFBuilder.DefaultClassLocation)
    FileUtils.deleteDirectory(file)
    file.mkdirs()
    val udfMap = evalUDF(udfCodeMap)
    JarUtils.buildJar(DefaultClassLocation, DefaultJarLocation)
    udfMap
  }
}
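A brief usage sketch. The UDF name and the source string below are made up, and the exact source-string format expected by evalUDF is an assumption; in sope these definitions normally come from the YAML transformation configuration.

// Hypothetical dynamic UDF definitions: UDF name -> Scala function source compiled at runtime.
val dynamicUdfs: Map[String, UserDefinedFunction] = UDFBuilder.buildDynamicUDFs(
  Map("normalize_name" -> "(s: String) => s.trim.toLowerCase")
)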
Example 4
Source File: UDFRegistration.scala From sope with Apache License 2.0 | 5 votes |
package com.sope.etl.register

import com.sope.etl.{SopeETLConfig, getClassInstance}
import com.sope.utils.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.UserDefinedFunction

// NOTE: the enclosing object declaration is truncated in this listing;
// only the registration helper survives below.
  def registerCustomUDFs(sqlContext: SQLContext): Unit = {
    SopeETLConfig.UDFRegistrationConfig match {
      case Some(classStr) =>
        logInfo(s"Registering custom UDFs from $classStr")
        getClassInstance[UDFRegistration](classStr) match {
          case Some(udfClass) =>
            udfClass.performRegistration(sqlContext)
            logInfo("Successfully registered custom UDFs")
          case _ => logError(s"UDF Registration failed")
        }
      case None => logInfo("No class defined for registering Custom udfs")
    }
  }
}
Example 5
Source File: DataFrameTfrConverter.scala From ecosystem with Apache License 2.0 | 5 votes |
package org.tensorflow.spark.datasources.tfrecords.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder

object DataFrameTfrConverter {

  def getRowToTFRecordExampleUdf: UserDefinedFunction = udf(rowToTFRecordExampleUdf _)

  private def rowToTFRecordExampleUdf(row: Row): Array[Byte] = {
    DefaultTfRecordRowEncoder.encodeExample(row).toByteArray
  }

  def getRowToTFRecordSequenceExampleUdf: UserDefinedFunction = udf(rowToTFRecordSequenceExampleUdf _)

  private def rowToTFRecordSequenceExampleUdf(row: Row): Array[Byte] = {
    DefaultTfRecordRowEncoder.encodeSequenceExample(row).toByteArray
  }
}
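A minimal usage sketch, assuming a DataFrame named df whose columns should be serialized; df and the output column name are hypothetical. Because the underlying function takes a Row, the UDF is applied to a struct of the columns to encode.

import org.apache.spark.sql.functions.{col, struct}

// Hypothetical DataFrame `df`: pack its columns into a struct so the Row-based UDF can encode them.
val serialized = df.select(
  DataFrameTfrConverter.getRowToTFRecordExampleUdf(struct(df.columns.map(col): _*)).as("tfrecord")
)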
Example 6
Source File: UDFs.scala From albedo with MIT License | 5 votes |
package ws.vinta.albedo.closures

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import ws.vinta.albedo.closures.StringFunctions._

import scala.util.control.Breaks.{break, breakable}

object UDFs extends Serializable {

  def containsAnyOfUDF(substrings: Array[String], shouldLower: Boolean = false): UserDefinedFunction = udf[Double, String]((text: String) => {
    var result = 0.0
    breakable {
      for (substring <- substrings) {
        if (text.contains(substring)) {
          result = 1.0
          break
        }
      }
    }
    result
  })

  def toArrayUDF: UserDefinedFunction = udf[Array[Double], Vector]((vector: Vector) => {
    vector.toArray
  })

  def numNonzerosOfVectorUDF: UserDefinedFunction = udf[Int, Vector]((vector: Vector) => {
    vector.numNonzeros
  })

  def cleanCompanyUDF: UserDefinedFunction = udf[String, String]((company: String) => {
    val temp1 = company
      .toLowerCase()
      .replaceAll("""\b(.com|.net|.org|.io|.co.uk|.co|.eu|.fr|.de|.ru)\b""", "")
      .replaceAll("""\b(formerly|previously|ex\-)\b""", "")
      .replaceAll("""\W+""", " ")
      .replaceAll("""\s+""", " ")
      .replaceAll("""\b(http|https|www|co ltd|pvt ltd|ltd|inc|llc)\b""", "")
      .trim()
    val temp2 = extractWordsIncludeCJK(temp1).mkString(" ")
    if (temp2.isEmpty) "__empty" else temp2
  })

  def cleanEmailUDF: UserDefinedFunction = udf[String, String]((email: String) => {
    val temp1 = email.toLowerCase().trim()
    val temp2 = extractEmailDomain(temp1)
    if (temp2.isEmpty) "__empty" else temp2
  })

  def cleanLocationUDF: UserDefinedFunction = udf[String, String]((location: String) => {
    val temp1 = try {
      val pattern = s"([$wordPatternIncludeCJK]+),\\s*([$wordPatternIncludeCJK]+)".r
      val pattern(city, _) = location
      city
    } catch {
      case _: MatchError => {
        location
      }
    }
    val temp2 = temp1
      .toLowerCase()
      .replaceAll("""[~!@#$^%&*\\(\\)_+={}\\[\\]|;:\"'<,>.?`/\\\\-]+""", " ")
      .replaceAll("""\s+""", " ")
      .replaceAll("""\b(city)\b""", "")
      .trim()
    val temp3 = extractWordsIncludeCJK(temp2).mkString(" ")
    if (temp3.isEmpty) "__empty" else temp3
  })

  def repoLanguageIndexInUserRecentRepoLanguagesUDF = udf((repo_language: String, user_recent_repo_languages: Seq[String]) => {
    val index = user_recent_repo_languages.indexOf(repo_language.toLowerCase())
    if (index < 0) user_recent_repo_languages.size + 50 else index
  })

  def repoLanguageCountInUserRecentRepoLanguagesUDF = udf((repo_language: String, user_recent_repo_languages: Seq[String]) => {
    user_recent_repo_languages.count(_ == repo_language.toLowerCase())
  })
}
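A brief usage sketch. The usersDF DataFrame and its bio and company columns are hypothetical, not taken from the albedo project; the sketch only shows how these parameterized and parameterless UDFs are applied to columns.

import org.apache.spark.sql.functions.col

// Hypothetical users DataFrame with string columns "bio" and "company".
val flagged = usersDF
  .withColumn("mentions_spark", UDFs.containsAnyOfUDF(Array("scala", "spark"))(col("bio")))
  .withColumn("company_clean", UDFs.cleanCompanyUDF(col("company")))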
Example 7
Source File: SparkStreamingPCatalogUSDemo.scala From gimel with Apache License 2.0 | 5 votes |
package com.paypal.gimel.examples

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.streaming._

import com.paypal.gimel.{DataSet, DataStream}
import com.paypal.gimel.logger.Logger

object SparkStreamingPCatalogUSDemo {

  // Define Geo Function
  case class Geo(lat: Double, lon: Double)

  val myUDF: UserDefinedFunction = udf((lat: Double, lon: Double) => Geo(lat, lon))

  def main(args: Array[String]) {
    // Creating SparkContext
    val sparkSession = SparkSession
      .builder()
      .enableHiveSupport()
      .getOrCreate()
    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")
    val sqlContext = sparkSession.sqlContext
    val ssc = new StreamingContext(sc, Seconds(10))
    val logger = Logger(this.getClass.getName)

    // Initiating PCatalog DataSet and DataStream
    val dataSet = DataSet(sparkSession)
    val dataStream = DataStream(ssc)

    // Reading from HDFS Dataset
    logger.info("Reading address_geo HDFS Dataset")
    val geoLookUpDF = dataSet.read("pcatalog.address_geo")
    val geoLookUp = geoLookUpDF.withColumn("geo", myUDF(geoLookUpDF("lat"), geoLookUpDF("lon")))
      .drop("lat").drop("lon")
    geoLookUp.cache()
    logger.info("Read " + geoLookUp.count() + " records")

    // Reading from Kafka DataStream and Loading into Elastic Search Dataset
    val streamingResult = dataStream.read("pcatalog.kafka_transactions")
    streamingResult.clearCheckPoint("OneTimeOnly")
    streamingResult.dStream.foreachRDD { rdd =>
      if (rdd.count() > 0) {
        streamingResult.getCurrentCheckPoint(rdd)
        val txnDF = streamingResult.convertAvroToDF(sqlContext, streamingResult.convertBytesToAvro(rdd))
        val resultSet = txnDF.join(geoLookUp, txnDF("account_number") === geoLookUp("customer_id"))
          .selectExpr("CONCAT(time_created,'000') AS time_created", "geo", "usd_amount")
        dataSet.write("pcatalog.elastic_transactions_dmz", resultSet)
        streamingResult.saveCurrentCheckPoint()
      }
    }

    // Start Streaming
    dataStream.streamingContext.start()
    dataStream.streamingContext.awaitTermination()

    sc.stop()
  }
}
Example 8
Source File: ClassifyLanguagesUDF.scala From jgit-spark-connector with Apache License 2.0 | 5 votes |
package tech.sourced.engine.udf

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import tech.sourced.enry.Enry

// NOTE: the enclosing ClassifyLanguagesUDF object declaration is truncated in this listing;
// only the language-detection helper survives below.
  def getLanguage(isBinary: Boolean, path: String, content: Array[Byte]): Option[String] = {
    timer.time({
      if (isBinary) {
        None
      } else {
        val lang = try {
          Enry.getLanguage(path, content)
        } catch {
          case e @ (_: RuntimeException | _: Exception) =>
            log.error(s"get language for file '$path' failed", e)
            null
        }
        if (null == lang || lang.isEmpty) None else Some(lang)
      }
    })
  }
}
Example 9
Source File: ExtractUASTsUDF.scala From jgit-spark-connector with Apache License 2.0 | 5 votes |
package tech.sourced.engine.udf

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import tech.sourced.engine.util.Bblfsh

trait ExtractUASTsUDF {

  def extractUASTs(path: String,
                   content: Array[Byte],
                   lang: String = null,
                   config: Bblfsh.Config): Seq[Array[Byte]] = {
    if (content == null || content.isEmpty) {
      Seq()
    } else {
      Bblfsh.extractUAST(path, content, lang, config)
    }
  }
}

case object ExtractUASTsUDF extends CustomUDF with ExtractUASTsUDF {
  override val name = "extractUASTs"

  override def apply(session: SparkSession): UserDefinedFunction = {
    val configB = session.sparkContext.broadcast(Bblfsh.getConfig(session))
    udf[Seq[Array[Byte]], String, Array[Byte], String]((path, content, lang) =>
      extractUASTs(path, content, lang, configB.value))
  }
}
Example 10
Source File: Utils.scala From Mastering-Machine-Learning-with-Spark-2.x with MIT License | 5 votes |
package com.packtpub.mmlwspark.utils

import org.apache.spark.h2o.H2OContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.UserDefinedFunction
import water.fvec.H2OFrame

object Utils {

  def colTransform(hf: H2OFrame, udf: UserDefinedFunction, colName: String)
                  (implicit h2oContext: H2OContext, sqlContext: SQLContext): H2OFrame = {
    import sqlContext.implicits._
    val name = hf.key.toString
    val colHf = hf(Array(colName))
    val df = h2oContext.asDataFrame(colHf)
    val result = h2oContext.asH2OFrame(df.withColumn(colName, udf($"${colName}")), s"${name}_${colName}")
    colHf.delete()
    result
  }

  def let[A](in: A)(body: A => Unit) = {
    body(in)
    in
  }
}
Example 11
Source File: QueryXPathUDF.scala From jgit-spark-connector with Apache License 2.0 | 5 votes |
package tech.sourced.engine.udf

import gopkg.in.bblfsh.sdk.v1.uast.generated.Node
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import tech.sourced.engine.util.Bblfsh

case object QueryXPathUDF extends CustomUDF {
  override val name = "queryXPath"

  override def apply(session: SparkSession): UserDefinedFunction = {
    val configB = session.sparkContext.broadcast(Bblfsh.getConfig(session))
    udf[Seq[Array[Byte]], Seq[Array[Byte]], String]((nodes, query) =>
      queryXPath(nodes, query, configB.value))
  }

  private def queryXPath(nodes: Seq[Array[Byte]],
                         query: String,
                         config: Bblfsh.Config): Seq[Array[Byte]] = {
    timer.time({
      if (nodes == null) {
        return null
      }
      nodes.map(Node.parseFrom).flatMap(n => {
        val result = Bblfsh.filter(n, query, config)
        if (result == null) {
          None
        } else {
          result.toIterator
        }
      }).map(_.toByteArray)
    })
  }
}
Example 12
Source File: ExtractTokensUDF.scala From jgit-spark-connector with Apache License 2.0 | 5 votes |
package tech.sourced.engine.udf

import gopkg.in.bblfsh.sdk.v1.uast.generated.Node
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf

case object ExtractTokensUDF extends CustomUDF {
  override val name = "extractTokens"

  override def apply(session: SparkSession): UserDefinedFunction =
    udf[Seq[String], Seq[Array[Byte]]](extractTokens)

  private def extractTokens(nodes: Seq[Array[Byte]]): Seq[String] = {
    timer.time({
      if (nodes == null) {
        Seq()
      } else {
        nodes.map(Node.parseFrom).map(_.token)
      }
    })
  }
}
Example 13
Source File: udfs.scala From mmlspark with MIT License | 5 votes |
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Column
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.DoubleType

import scala.collection.mutable

//scalastyle:off
object udfs {

  def get_value_at(colName: String, i: Int): Column = {
    udf({ vec: org.apache.spark.ml.linalg.Vector => vec(i) }, DoubleType)(col(colName))
  }

  val to_vector: UserDefinedFunction = udf({ arr: Seq[Double] => Vectors.dense(arr.toArray) }, VectorType)

  def to_vector(colName: String): Column = to_vector(col(colName))
}
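A brief usage sketch; the DataFrame df and its "features" array column are hypothetical and only illustrate how the helpers above are applied.

// Hypothetical DataFrame with an Array[Double] column named "features".
val withVec = df.withColumn("features_vec", udfs.to_vector("features"))
val withFirst = withVec.withColumn("first_feature", udfs.get_value_at("features_vec", 0))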
Example 14
Source File: UDFTransformer.scala From mmlspark with MIT License | 5 votes |
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasInputCols, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.core.env.InternalWrapper
import com.microsoft.ml.spark.core.serialize.ComplexParam
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{ParamMap, UDFParam, UDPyFParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.col

object UDFTransformer extends ComplexParamsReadable[UDFTransformer]

// NOTE: the UDFTransformer class declaration is truncated in this listing;
// the methods below belong to that class.
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    if (isSet(inputCol)) {
      dataset.withColumn(getOutputCol, applyUDF(dataset.col(getInputCol)))
    } else {
      dataset.withColumn(getOutputCol, applyUDFOnCols(getInputCols.map(col): _*))
    }
  }

  def validateAndTransformSchema(schema: StructType): StructType = {
    if (isSet(inputCol)) schema(getInputCol) else schema(Set(getInputCols: _*))
    schema.add(StructField(getOutputCol, getDataType))
  }

  def transformSchema(schema: StructType): StructType = validateAndTransformSchema(schema)

  def copy(extra: ParamMap): UDFTransformer = defaultCopy(extra)
}
Example 15
Source File: IndexToValue.scala From mmlspark with MIT License | 5 votes |
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.featurize

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.core.schema.{CategoricalColumnInfo, CategoricalUtilities}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import com.microsoft.ml.spark.core.schema.SchemaConstants._

import scala.reflect.ClassTag
import reflect.runtime.universe.TypeTag

object IndexToValue extends DefaultParamsReadable[IndexToValue]

// NOTE: the IndexToValue class declaration is truncated in this listing;
// the methods below belong to that class.
  override def transform(dataset: Dataset[_]): DataFrame = {
    val info = new CategoricalColumnInfo(dataset.toDF(), getInputCol)
    require(info.isCategorical, "column " + getInputCol + " is not Categorical")
    val dataType = info.dataType
    val getLevel =
      dataType match {
        case _: IntegerType => getLevelUDF[Int](dataset)
        case _: LongType => getLevelUDF[Long](dataset)
        case _: DoubleType => getLevelUDF[Double](dataset)
        case _: StringType => getLevelUDF[String](dataset)
        case _: BooleanType => getLevelUDF[Boolean](dataset)
        case _ => throw new Exception("Unsupported type " + dataType.toString)
      }
    dataset.withColumn(getOutputCol, getLevel(dataset(getInputCol)).as(getOutputCol))
  }

  private class Default[T] { var value: T = _ }

  def getLevelUDF[T: TypeTag](dataset: Dataset[_])(implicit ct: ClassTag[T]): UserDefinedFunction = {
    val map = CategoricalUtilities.getMap[T](dataset.schema(getInputCol).metadata)
    udf((index: Int) => {
      if (index == map.numLevels && map.hasNullLevel) {
        new Default[T].value
      } else {
        map.getLevelOption(index)
          .getOrElse(throw new IndexOutOfBoundsException(
            "Invalid metadata: Index greater than number of levels in metadata, " +
              s"index: $index, levels: ${map.numLevels}"))
      }
    })
  }

  def transformSchema(schema: StructType): StructType = {
    val metadata = schema(getInputCol).metadata
    val dataType =
      if (metadata.contains(MMLTag)) {
        CategoricalColumnInfo.getDataType(metadata, throwOnInvalid = true).get
      } else {
        schema(getInputCol).dataType
      }
    val newField = StructField(getOutputCol, dataType)
    if (schema.fieldNames.contains(getOutputCol)) {
      val index = schema.fieldIndex(getOutputCol)
      val fields = schema.fields
      fields(index) = newField
      StructType(fields)
    } else {
      schema.add(newField)
    }
  }

  def copy(extra: ParamMap): this.type = defaultCopy(extra)
}
Example 16
Source File: ServingUDFs.scala From mmlspark with MIT License | 5 votes |
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package org.apache.spark.sql.execution.streaming

import com.microsoft.ml.spark.io.http.HTTPResponseData
import com.microsoft.ml.spark.io.http.HTTPSchema.{binary_to_response, empty_response, string_to_response}
import org.apache.spark.sql.execution.streaming.continuous.HTTPSourceStateHolder
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{lit, struct, to_json, udf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Row}

import scala.util.Try

object ServingUDFs {

  private def jsonReply(c: Column) = string_to_response(to_json(c))

  def makeReplyUDF(data: Column, dt: DataType, code: Column = lit(200), reason: Column = lit("Success")): Column = {
    dt match {
      case NullType => empty_response(code, reason)
      case StringType => string_to_response(data, code, reason)
      case BinaryType => binary_to_response(data)
      case _: StructType => jsonReply(data)
      case _: MapType => jsonReply(data)
      case at: ArrayType => at.elementType match {
        case _: StructType => jsonReply(data)
        case _: MapType => jsonReply(data)
        case _ => jsonReply(struct(data))
      }
      case _ => jsonReply(struct(data))
    }
  }

  private def sendReplyHelper(mapper: Row => HTTPResponseData)(serviceName: String, reply: Row, id: Row): Boolean = {
    if (Option(reply).isEmpty || Option(id).isEmpty) {
      null.asInstanceOf[Boolean] //scalastyle:ignore null
    } else {
      Try(HTTPSourceStateHolder.getServer(serviceName).replyTo(id.getString(0), id.getString(1), mapper(reply)))
        .toOption.isDefined
    }
  }

  def sendReplyUDF: UserDefinedFunction = {
    val toData = HTTPResponseData.makeFromRowConverter
    udf(sendReplyHelper(toData) _, BooleanType)
  }
}
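A brief usage sketch of makeReplyUDF, which builds the reply column returned to the HTTP serving sink. The scoredDF streaming DataFrame and its "prediction" column are hypothetical; only the makeReplyUDF call reflects the signature shown above.

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType

// Hypothetical scored streaming DataFrame with a string "prediction" column.
val replies = scoredDF.withColumn(
  "reply",
  ServingUDFs.makeReplyUDF(col("prediction"), StringType)
)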
Example 17
Source File: package.scala From osmesa with Apache License 2.0 | 5 votes |
package osmesa.analytics.stats

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import vectorpipe.util._

package object functions {
  // A brief note about style:
  // Spark functions are typically defined using snake_case, therefore so are the UDFs;
  // internal helper functions use standard Scala naming conventions.

  lazy val merge_measurements: UserDefinedFunction = udf(_mergeDoubleCounts)

  lazy val sum_measurements: UserDefinedFunction = udf { counts: Iterable[Map[String, Double]] =>
    Option(counts.reduce(_mergeDoubleCounts)).filter(_.nonEmpty).orNull
  }

  lazy val sum_count_values: UserDefinedFunction = udf { counts: Map[String, Int] =>
    counts.values.sum
  }

  lazy val simplify_measurements: UserDefinedFunction = udf { counts: Map[String, Double] =>
    counts.filter(_._2 != 0)
  }

  lazy val simplify_counts: UserDefinedFunction = udf { counts: Map[String, Int] =>
    counts.filter(_._2 != 0)
  }

  private val _mergeIntCounts = (a: Map[String, Int], b: Map[String, Int]) =>
    mergeMaps(Option(a).getOrElse(Map.empty), Option(b).getOrElse(Map.empty))(_ + _)

  private val _mergeDoubleCounts = (a: Map[String, Double], b: Map[String, Double]) =>
    mergeMaps(Option(a).getOrElse(Map.empty), Option(b).getOrElse(Map.empty))(_ + _)
}
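A brief usage sketch. The changesets DataFrame, its user_id grouping key, and its Map[String, Double] "measurements" column are hypothetical; the sketch only illustrates combining these UDFs with collect_list in an aggregation.

import org.apache.spark.sql.functions.{col, collect_list}
import osmesa.analytics.stats.functions._

// Hypothetical changesets DataFrame with a Map[String, Double] "measurements" column per row.
val perUser = changesets
  .groupBy(col("user_id"))
  .agg(sum_measurements(collect_list(col("measurements"))).as("measurements"))
  .withColumn("measurements", simplify_measurements(col("measurements")))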
Example 18
Source File: functions.scala From spark-nlp with Apache License 2.0 | 5 votes |
package com.johnsnowlabs.nlp

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{array, col, explode, udf}
import org.apache.spark.sql.types.DataType

import scala.reflect.runtime.universe._

object functions {

  implicit class FilterAnnotations(dataset: DataFrame) {
    def filterByAnnotationsCol(column: String, function: Seq[Annotation] => Boolean): DataFrame = {
      val meta = dataset.schema(column).metadata
      val func = udf { annotatorProperties: Seq[Row] =>
        function(annotatorProperties.map(Annotation(_)))
      }
      dataset.filter(func(col(column)).as(column, meta))
    }
  }

  def mapAnnotations[T](function: Seq[Annotation] => T, outputType: DataType): UserDefinedFunction =
    udf({ annotatorProperties: Seq[Row] =>
      function(annotatorProperties.map(Annotation(_)))
    }, outputType)

  def mapAnnotationsStrict(function: Seq[Annotation] => Seq[Annotation]): UserDefinedFunction =
    udf { annotatorProperties: Seq[Row] =>
      function(annotatorProperties.map(Annotation(_)))
    }

  implicit class MapAnnotations(dataset: DataFrame) {
    def mapAnnotationsCol[T: TypeTag](column: String, outputCol: String, function: Seq[Annotation] => T): DataFrame = {
      val meta = dataset.schema(column).metadata
      val func = udf { annotatorProperties: Seq[Row] =>
        function(annotatorProperties.map(Annotation(_)))
      }
      dataset.withColumn(outputCol, func(col(column)).as(outputCol, meta))
    }
  }

  implicit class EachAnnotations(dataset: DataFrame) {

    import dataset.sparkSession.implicits._

    def eachAnnotationsCol[T: TypeTag](column: String, function: Seq[Annotation] => Unit): Unit = {
      dataset.select(column).as[Array[Annotation]].foreach(function(_))
    }
  }

  implicit class ExplodeAnnotations(dataset: DataFrame) {
    def explodeAnnotationsCol[T: TypeTag](column: String, outputCol: String): DataFrame = {
      val meta = dataset.schema(column).metadata
      dataset.
        withColumn(outputCol, explode(col(column))).
        withColumn(outputCol, array(col(outputCol)).as(outputCol, meta))
    }
  }
}
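A brief usage sketch of the implicit classes. The annotatedDF DataFrame and its "token" annotation column are hypothetical, assumed to come from a Spark NLP pipeline; the helper calls follow the signatures defined above.

import com.johnsnowlabs.nlp.functions._

// Hypothetical DataFrame produced by a Spark NLP pipeline, with an annotation column named "token".
val tokenTexts = annotatedDF.mapAnnotationsCol[Seq[String]]("token", "token_texts", _.map(_.result))
val nonEmpty = annotatedDF.filterByAnnotationsCol("token", _.nonEmpty)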