org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions Scala Examples
The following examples show how to use org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.
Each example is taken from an open-source project; the source file, project, and license are noted above each snippet.
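JDBCOptions lives in org.apache.spark.sql.execution.datasources.jdbc, an internal Spark SQL package, so its constructors are not a stable public API. As a quick orientation before the examples, here is a minimal sketch of the two constructor shapes the examples below rely on, assuming a Spark 2.x build; the URL, table name, and credentials are placeholders.

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

// A minimal sketch, assuming Spark 2.x: JDBCOptions is an internal class,
// so these constructors may change between releases. All connection values
// below are placeholders.
val options = new JDBCOptions(Map(
  JDBCOptions.JDBC_URL -> "jdbc:postgresql://localhost:5432/testdb",
  JDBCOptions.JDBC_TABLE_NAME -> "public.events",
  JDBCOptions.JDBC_DRIVER_CLASS -> "org.postgresql.Driver",
  "user" -> "spark",
  "password" -> "secret"))

// The three-argument (url, table, parameters) form used by several examples below.
val sameOptions = new JDBCOptions(
  "jdbc:postgresql://localhost:5432/testdb",
  "public.events",
  Map("user" -> "spark", "password" -> "secret"))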
Example 1
Source File: PostgresDialect.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._

private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need to be in autocommit=false if we
    // actually want to have fetchsize be non 0 (all the rows). This allows us to not have to
    // cache all the rows inside the driver when fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }
  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
}
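Spark registers its built-in dialects (including this PostgresDialect) internally, but the same mechanism is available to applications through the public JdbcDialects API. The sketch below is illustrative only: the dialect and URL prefix are hypothetical and not part of the file above.

import java.sql.Types

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._

// A hypothetical custom dialect, registered through the public API so that
// Spark's JDBC data source picks it up for matching connection URLs.
object MyPostgresLikeDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mypg")

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.VARCHAR))
    case _          => None
  }
}

// Typically called once during application startup.
JdbcDialects.registerDialect(MyPostgresLikeDialect)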
Example 2
Source File: PostgresDialect.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._

private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" if precision > 0 => Some(DecimalType.bounded(precision, scale))
    case "numeric" | "decimal" =>
      // SPARK-26538: handle numeric without explicit precision and scale.
      Some(DecimalType.SYSTEM_DEFAULT)
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)

  override def getTruncateQuery(
      table: String,
      cascade: Option[Boolean] = isCascadingTruncateTable): String = {
    cascade match {
      case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE"
      case _ => s"TRUNCATE TABLE ONLY $table"
    }
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need to be in autocommit=false if we
    // actually want to have fetchsize be non 0 (all the rows). This allows us to not have to
    // cache all the rows inside the driver when fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }
  }
}
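The getTruncateQuery override above determines the SQL Spark issues when truncating a Postgres table. A small illustrative sketch of asking a dialect for that SQL through the public JdbcDialects API, assuming a Spark 2.4+ build where the two-argument getTruncateQuery exists; the URL and table name are placeholders, and the commented output reflects the implementation shown above.

import org.apache.spark.sql.jdbc.JdbcDialects

// Illustrative only; placeholders for the connection URL and table name.
val dialect = JdbcDialects.get("jdbc:postgresql://localhost:5432/testdb")
dialect.getTruncateQuery("public.events")              // TRUNCATE TABLE ONLY public.events
dialect.getTruncateQuery("public.events", Some(true))  // TRUNCATE TABLE ONLY public.events CASCADE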
Example 3
Source File: JdbcOutput.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.jdbc

import java.io.{Serializable => JSerializable}
import java.util.Properties

import com.stratio.sparta.sdk.pipeline.output.Output._
import com.stratio.sparta.sdk.pipeline.output.SaveModeEnum.SpartaSaveMode
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.SpartaJdbcUtils
import org.apache.spark.sql.jdbc.SpartaJdbcUtils._

import scala.collection.JavaConversions._
import scala.util.{Failure, Success, Try}

class JdbcOutput(name: String, properties: Map[String, JSerializable]) extends Output(name, properties) {

  require(properties.getString("url", None).isDefined, "url must be provided")

  val url = properties.getString("url")

  override def supportedSaveModes: Seq[SpartaSaveMode] =
    Seq(SaveModeEnum.Append, SaveModeEnum.ErrorIfExists, SaveModeEnum.Ignore, SaveModeEnum.Overwrite)

  //scalastyle:off
  override def save(dataFrame: DataFrame, saveMode: SpartaSaveMode, options: Map[String, String]): Unit = {
    validateSaveMode(saveMode)

    val tableName = getTableNameFromOptions(options)
    val sparkSaveMode = getSparkSaveMode(saveMode)
    val connectionProperties = new JDBCOptions(url,
      tableName,
      propertiesWithCustom.mapValues(_.toString).filter(_._2.nonEmpty)
    )

    Try {
      if (sparkSaveMode == SaveMode.Overwrite)
        SpartaJdbcUtils.dropTable(url, connectionProperties, tableName)
      SpartaJdbcUtils.tableExists(url, connectionProperties, tableName, dataFrame.schema)
    } match {
      case Success(tableExists) =>
        if (tableExists) {
          if (saveMode == SaveModeEnum.Upsert) {
            val updateFields = getPrimaryKeyOptions(options) match {
              case Some(pk) => pk.split(",").toSeq
              case None => dataFrame.schema.fields.filter(stField =>
                stField.metadata.contains(Output.PrimaryKeyMetadataKey)).map(_.name).toSeq
            }
            SpartaJdbcUtils.upsertTable(dataFrame, url, tableName, connectionProperties, updateFields)
          }

          if (saveMode == SaveModeEnum.Ignore) return

          if (saveMode == SaveModeEnum.ErrorIfExists) sys.error(s"Table $tableName already exists.")

          if (saveMode == SaveModeEnum.Append || saveMode == SaveModeEnum.Overwrite)
            SpartaJdbcUtils.saveTable(dataFrame, url, tableName, connectionProperties)
        } else log.warn(s"Table not created in Postgres: $tableName")
      case Failure(e) =>
        closeConnection()
        log.error(s"Error creating/dropping table $tableName")
    }
  }

  override def cleanUp(options: Map[String, String]): Unit = {
    log.info(s"Closing connections in JDBC Output: $name")
    closeConnection()
  }
}
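The JDBCOptions instance built here uses the three-argument (url, table, parameters) constructor shown in the sketch at the top of this page; the project-specific SpartaJdbcUtils helpers then reuse that single options object for the table-existence check, upserts, and saves.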
Example 4
Source File: PostgresOutput.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.postgres

import java.io.{InputStream, Serializable => JSerializable}
import java.util.Properties

import com.stratio.sparta.sdk.pipeline.output.Output._
import com.stratio.sparta.sdk.pipeline.output.SaveModeEnum.SpartaSaveMode
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.SpartaJdbcUtils
import org.apache.spark.sql.jdbc.SpartaJdbcUtils._
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

import scala.collection.JavaConversions._
import scala.util.{Failure, Success, Try}

class PostgresOutput(name: String, properties: Map[String, JSerializable]) extends Output(name, properties) {

  require(properties.getString("url", None).isDefined, "Postgres url must be provided")

  val url = properties.getString("url")
  val bufferSize = properties.getString("bufferSize", "65536").toInt
  val delimiter = properties.getString("delimiter", "\t")
  val newLineSubstitution = properties.getString("newLineSubstitution", " ")
  val encoding = properties.getString("encoding", "UTF8")

  override def supportedSaveModes: Seq[SpartaSaveMode] =
    Seq(SaveModeEnum.Append, SaveModeEnum.Overwrite, SaveModeEnum.Upsert)

  override def save(dataFrame: DataFrame, saveMode: SpartaSaveMode, options: Map[String, String]): Unit = {
    validateSaveMode(saveMode)

    val tableName = getTableNameFromOptions(options)
    val sparkSaveMode = getSparkSaveMode(saveMode)
    val connectionProperties = new JDBCOptions(url,
      tableName,
      propertiesWithCustom.mapValues(_.toString).filter(_._2.nonEmpty)
    )

    Try {
      if (sparkSaveMode == SaveMode.Overwrite)
        SpartaJdbcUtils.dropTable(url, connectionProperties, tableName)
      SpartaJdbcUtils.tableExists(url, connectionProperties, tableName, dataFrame.schema)
    } match {
      case Success(tableExists) =>
        if (tableExists)
          if (saveMode == SaveModeEnum.Upsert) {
            val updateFields = getPrimaryKeyOptions(options) match {
              case Some(pk) => pk.split(",").toSeq
              case None => dataFrame.schema.fields.filter(stField =>
                stField.metadata.contains(Output.PrimaryKeyMetadataKey)).map(_.name).toSeq
            }
            SpartaJdbcUtils.upsertTable(dataFrame, url, tableName, connectionProperties, updateFields)
          } else {
            dataFrame.foreachPartition { rows =>
              val conn = getConnection(connectionProperties)
              val cm = new CopyManager(conn.asInstanceOf[BaseConnection])

              cm.copyIn(
                s"""COPY $tableName FROM STDIN WITH (NULL 'null', ENCODING '$encoding', FORMAT CSV, DELIMITER E'$delimiter')""",
                rowsToInputStream(rows)
              )
            }
          }
        else log.warn(s"Table not created in Postgres: $tableName")
      case Failure(e) =>
        closeConnection()
        log.error(s"Error creating/dropping table $tableName")
    }
  }

  def rowsToInputStream(rows: Iterator[Row]): InputStream = {
    val bytes: Iterator[Byte] = rows.flatMap { row =>
      (row.mkString(delimiter).replace("\n", newLineSubstitution) + "\n").getBytes(encoding)
    }

    new InputStream {
      override def read(): Int =
        if (bytes.hasNext) bytes.next & 0xff
        else -1
    }
  }

  override def cleanUp(options: Map[String, String]): Unit = {
    log.info(s"Closing connections in Postgres Output: $name")
    closeConnection()
  }
}
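Instead of issuing row-by-row INSERT statements over JDBC, this output streams each partition to Postgres through the driver's CopyManager with a COPY ... FROM STDIN command, which is generally much faster for bulk loads; rowsToInputStream renders rows as delimited text lazily, so a partition never has to be fully materialized in memory before it is sent.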
Example 5
Source File: PostgresDialect.scala From sparkoscope with Apache License 2.0
This copy of PostgresDialect.scala is identical to the one shown in Example 1, so the code is not repeated here.
Example 6
Source File: MemsqlRDD.scala From memsql-spark-connector with Apache License 2.0
package com.memsql.spark

import java.sql.{Connection, PreparedStatement, ResultSet}

import com.memsql.spark.SQLGen.VariableList
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}

case class MemsqlRDD(query: String,
                     variables: VariableList,
                     options: MemsqlOptions,
                     schema: StructType,
                     expectedOutput: Seq[Attribute],
                     @transient val sc: SparkContext)
    extends RDD[Row](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    MemsqlQueryHelpers.GetPartitions(options, query, variables)

  override def compute(rawPartition: Partition, context: TaskContext): Iterator[Row] = {
    var closed = false
    var rs: ResultSet = null
    var stmt: PreparedStatement = null
    var conn: Connection = null
    var partition: MemsqlPartition = rawPartition.asInstanceOf[MemsqlPartition]

    def tryClose(name: String, what: AutoCloseable): Unit = {
      try {
        if (what != null) { what.close() }
      } catch {
        case e: Exception => logWarning(s"Exception closing $name", e)
      }
    }

    def close(): Unit = {
      if (closed) { return }
      tryClose("resultset", rs)
      tryClose("statement", stmt)
      tryClose("connection", conn)
      closed = true
    }

    context.addTaskCompletionListener { context =>
      close()
    }

    conn = JdbcUtils.createConnectionFactory(partition.connectionInfo)()
    stmt = conn.prepareStatement(partition.query)
    JdbcHelpers.fillStatement(stmt, partition.variables)
    rs = stmt.executeQuery()

    var rowsIter = JdbcUtils.resultSetToRows(rs, schema)

    if (expectedOutput.nonEmpty) {
      val schemaDatatypes = schema.map(_.dataType)
      val expectedDatatypes = expectedOutput.map(_.dataType)

      if (schemaDatatypes != expectedDatatypes) {
        val columnEncoders = schemaDatatypes.zip(expectedDatatypes).zipWithIndex.map {
          case ((_: StringType, _: NullType), _) => ((_: Row) => null)
          case ((_: ShortType, _: BooleanType), i) => ((r: Row) => r.getShort(i) != 0)
          case ((_: IntegerType, _: BooleanType), i) => ((r: Row) => r.getInt(i) != 0)
          case ((_: LongType, _: BooleanType), i) => ((r: Row) => r.getLong(i) != 0)
          case ((l, r), i) => {
            options.assert(l == r, s"MemsqlRDD: unable to encode ${l} into ${r}")
            ((r: Row) => r.get(i))
          }
        }

        rowsIter = rowsIter
          .map(row => Row.fromSeq(columnEncoders.map(_(row))))
      }
    }

    CompletionIterator[Row, Iterator[Row]](new InterruptibleIterator[Row](context, rowsIter), close)
  }
}
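Example 6 builds each partition's row iterator directly from Spark's internal JDBC helpers. A minimal standalone sketch of that pattern follows, assuming Spark 2.x internals (createConnectionFactory and resultSetToRows are internal APIs whose signatures can change between releases); the connection settings, query, and schema are placeholders.

import java.sql.ResultSet

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

// Placeholder connection settings; a matching JDBC driver must be on the classpath.
val options = new JDBCOptions(
  "jdbc:mysql://localhost:3306/test",
  "t",
  Map("user" -> "root", "password" -> ""))

// Open a connection through Spark's factory, run a query, and convert the
// ResultSet into Spark Rows using the declared schema.
val conn = JdbcUtils.createConnectionFactory(options)()
try {
  val rs: ResultSet = conn.prepareStatement("SELECT id, name FROM t").executeQuery()
  val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
  val rows: Iterator[Row] = JdbcUtils.resultSetToRows(rs, schema)
  rows.foreach(println)
} finally {
  conn.close()
}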
Example 7
Source File: JdbcSinkDemo.scala From bahir with Apache License 2.0
package org.apache.bahir.examples.sql.streaming.jdbc

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.streaming.{OutputMode, Trigger}

object JdbcSinkDemo {

  private case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    if (args.length < 4) {
      // scalastyle:off println
      System.err.println("Usage: JdbcSinkDemo <jdbcUrl> <tableName> <username> <password>")
      // scalastyle:on
      System.exit(1)
    }

    val jdbcUrl = args(0)
    val tableName = args(1)
    val username = args(2)
    val password = args(3)

    val spark = SparkSession
      .builder()
      .appName("JdbcSinkDemo")
      .getOrCreate()

    // load data source
    val df = spark.readStream
      .format("rate")
      .option("numPartitions", "5")
      .option("rowsPerSecond", "100")
      .load()

    // change input value to a person object.
    import spark.implicits._
    val lines = df.select("value").as[Long].map { value =>
      Person(s"name_${value}", value.toInt % 30)
    }

    lines.printSchema()

    // write result
    val query = lines.writeStream
      .outputMode("append")
      .format("streaming-jdbc")
      .outputMode(OutputMode.Append)
      .option(JDBCOptions.JDBC_URL, jdbcUrl)
      .option(JDBCOptions.JDBC_TABLE_NAME, tableName)
      .option(JDBCOptions.JDBC_DRIVER_CLASS, "com.mysql.jdbc.Driver")
      .option(JDBCOptions.JDBC_BATCH_INSERT_SIZE, "5")
      .option("user", username)
      .option("password", password)
      .trigger(Trigger.ProcessingTime("10 seconds"))
      .start()

    query.awaitTermination()
  }
}
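In Spark 2.x the JDBCOptions constants used above resolve to plain option keys: JDBC_URL is "url", JDBC_TABLE_NAME is "dbtable", JDBC_DRIVER_CLASS is "driver", and JDBC_BATCH_INSERT_SIZE is "batchsize", so the same configuration can also be written with string keys when the internal class is not on the compile path.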
Example 8
Source File: JdbcSourceProvider.scala From bahir with Apache License 2.0
package org.apache.bahir.sql.streaming.jdbc

import scala.collection.JavaConverters._

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

class JdbcSourceProvider extends StreamWriteSupport with DataSourceRegister {

  override def createStreamWriter(queryId: String, schema: StructType,
      mode: OutputMode, options: DataSourceOptions): StreamWriter = {
    val optionMap = options.asMap().asScala.toMap
    // add this for parameter check.
    new JDBCOptions(optionMap)
    new JdbcStreamWriter(schema, optionMap)
  }

  // short name 'jdbc' is used for batch, chose a different name for streaming.
  override def shortName(): String = "streaming-jdbc"
}
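The shortName only takes effect if Spark can discover the provider: DataSourceRegister implementations are loaded through Java's ServiceLoader, so the class is normally listed in a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file on the classpath. Failing that, the fully qualified class name can be passed to format(...) directly, as an alternative to "streaming-jdbc" used in Example 7.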
Example 9
Source File: PostgresDialect.scala From multi-tenancy-spark with Apache License 2.0
This copy of PostgresDialect.scala is also identical to the one shown in Example 1, so the code is not repeated here.
Example 10
Source File: PostgresDialect.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._

private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def getTruncateQuery(table: String): String = {
    s"TRUNCATE TABLE ONLY $table"
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need to be in autocommit=false if we
    // actually want to have fetchsize be non 0 (all the rows). This allows us to not have to
    // cache all the rows inside the driver when fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }
  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
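Taken together, the PostgresDialect copies on this page track Spark's own evolution: the Spark 2.1-era version (Examples 1, 5, and 9) reports cascading truncation and has no truncate query override, the Spark 2.3.1 version adds getTruncateQuery with TRUNCATE TABLE ONLY, and the Spark 2.4-era XSQL version (Example 2) adds the cascade parameter and the SPARK-26538 handling of numeric columns declared without precision and scale.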