org.apache.spark.sql.jdbc.JdbcType Scala Examples
The following examples show how to use org.apache.spark.sql.jdbc.JdbcType.
Each example notes its source project and license.
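JdbcType itself is a small case class that pairs a database-specific DDL type name with the matching java.sql.Types constant. A minimal construction looks like this:

import java.sql.Types
import org.apache.spark.sql.jdbc.JdbcType

// Pair the database's DDL type name with its java.sql.Types code.
val varchar = JdbcType("VARCHAR(255)", Types.VARCHAR)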
Example 1
Source File: MemsqlDialect.scala, from memsql-spark-connector (Apache License 2.0)
package com.memsql.spark

import java.sql.Types

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._

case object MemsqlDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:memsql")

  val MEMSQL_DECIMAL_MAX_SCALE = 30

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case BooleanType   => Option(JdbcType("BOOL", Types.BOOLEAN))
    case ByteType      => Option(JdbcType("TINYINT", Types.TINYINT))
    case ShortType     => Option(JdbcType("SMALLINT", Types.SMALLINT))
    case FloatType     => Option(JdbcType("FLOAT", Types.FLOAT))
    case TimestampType => Option(JdbcType("TIMESTAMP(6)", Types.TIMESTAMP))
    case dt: DecimalType if dt.scale <= MEMSQL_DECIMAL_MAX_SCALE =>
      Option(JdbcType(s"DECIMAL(${dt.precision}, ${dt.scale})", Types.DECIMAL))
    case dt: DecimalType =>
      throw new IllegalArgumentException(
        s"Too big scale specified(${dt.scale}). MemSQL DECIMAL maximum scale is ${MEMSQL_DECIMAL_MAX_SCALE}")
    case NullType =>
      throw new IllegalArgumentException(
        "No corresponding MemSQL type found for NullType. If you want to use NullType, please write to an already existing MemSQL table.")
    case t => JdbcUtils.getCommonJDBCType(t)
  }

  override def getCatalystType(
      sqlType: Int,
      typeName: String,
      size: Int,
      md: MetadataBuilder): Option[DataType] = {
    (sqlType, typeName) match {
      case (Types.REAL, "FLOAT")        => Option(FloatType)
      case (Types.BIT, "BIT")           => Option(BinaryType)
      case (Types.TINYINT, "TINYINT")   => Option(ShortType)
      case (Types.SMALLINT, "SMALLINT") => Option(ShortType)
      case (Types.DECIMAL, "DECIMAL") =>
        if (size > DecimalType.MAX_PRECISION) {
          throw new IllegalArgumentException(
            s"DECIMAL precision ${size} exceeds max precision ${DecimalType.MAX_PRECISION}")
        } else {
          Option(DecimalType(size, md.build().getLong("scale").toInt))
        }
      case _ => None
    }
  }

  override def quoteIdentifier(colName: String): String = s"`$colName`"

  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
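A dialect like this only takes effect once it is registered with Spark's JdbcDialects registry. The sketch below shows how registration and the type mapping fit together; the URL and table name in the commented write are placeholders:

import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._

// Register the dialect so Spark consults it for any jdbc:memsql URL.
JdbcDialects.registerDialect(MemsqlDialect)

// Spark resolves each column's DDL type through getJDBCType; for example,
// a BooleanType column maps to MemSQL's BOOL.
assert(MemsqlDialect.getJDBCType(BooleanType).get.databaseTypeDefinition == "BOOL")

// With the dialect registered, an ordinary JDBC write picks it up:
// df.write.format("jdbc")
//   .option("url", "jdbc:memsql://host:3306/db")  // placeholder URL
//   .option("dbtable", "events")                  // placeholder table
//   .save()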
Example 2
Source File: JdbcUtil.scala, from bahir (Apache License 2.0)
package org.apache.bahir.sql.streaming.jdbc

import java.sql.{Connection, PreparedStatement}
import java.util.Locale

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object JdbcUtil {

  // Resolve the JdbcType for a Catalyst DataType: ask the dialect first, then
  // fall back to Spark's common mapping, and fail if neither knows the type.
  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }

  // A `JDBCValueSetter` copies one value from a `Row` into a `PreparedStatement`.
  // The `Int` argument is the 0-based position of the value in the `Row`; shifted
  // to 1-based, the same position addresses the statement parameter.
  type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        // Internal rows may carry UTF8String; external rows carry String.
        val strValue = row.get(pos) match {
          case str: UTF8String => str.toString
          case str: String => str
        }
        stmt.setString(pos + 1, strValue)

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // Remove type length parameters from the end of the type name,
      // e.g. "VARCHAR(255)" becomes "varchar".
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase(Locale.ROOT).split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
}
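To see how these pieces fit together, here is a minimal sketch that binds one Row into an INSERT statement. The H2 URL, the people table, and the two-column schema are assumptions for illustration; the table is presumed to already exist and the H2 driver to be on the classpath:

import java.sql.DriverManager

import org.apache.spark.sql.Row
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._

// Hypothetical in-memory database; any JDBC source works the same way.
val conn = DriverManager.getConnection("jdbc:h2:mem:test")
val dialect = JdbcDialects.get("jdbc:h2:mem:test")

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))

// Resolve one setter per column, once, from the schema.
val setters: Seq[JdbcUtil.JDBCValueSetter] =
  schema.fields.toSeq.map(f => JdbcUtil.makeSetter(conn, dialect, f.dataType))

val stmt = conn.prepareStatement("INSERT INTO people VALUES (?, ?)")
val row = Row(1, "alice")

// Each setter reads Row position pos and writes statement parameter pos + 1.
setters.zipWithIndex.foreach { case (setter, pos) => setter(stmt, row, pos) }
stmt.executeUpdate()

Resolving the setters once per schema, rather than per row, keeps the per-row work down to the positional binds, which is why the streaming sink in bahir structures the code this way.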