org.apache.parquet.io.api.RecordConsumer Scala Examples

The following examples show how to use org.apache.parquet.io.api.RecordConsumer. They are taken from open-source projects; the source file, originating project, and license are noted above each example.
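Before turning to the project examples, the following minimal sketch (not taken from any of the projects below; the two-column schema is assumed purely for illustration) shows the write protocol a RecordConsumer expects: each record is bracketed by startMessage/endMessage, each populated column by startField/endField, and values are emitted with the typed add* methods.

import org.apache.parquet.io.api.{Binary, RecordConsumer}

// Writes one record into an already-prepared RecordConsumer, assuming a schema with
// an INT64 column "id" at index 0 and a UTF8/BINARY column "name" at index 1.
def writeOneRecord(consumer: RecordConsumer): Unit = {
  consumer.startMessage()

  consumer.startField("id", 0)
  consumer.addLong(42L)
  consumer.endField("id", 0)

  consumer.startField("name", 1)
  consumer.addBinary(Binary.fromString("example"))
  consumer.endField("name", 1)

  consumer.endMessage()
}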
Example 1
Source File: TimestampLogicalType.scala    From embulk-output-s3_parquet   with MIT License
package org.embulk.output.s3_parquet.parquet

import java.time.ZoneId

import org.apache.parquet.io.api.RecordConsumer
import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types}
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.{
  MICROS,
  MILLIS,
  NANOS
}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.embulk.config.ConfigException
import org.embulk.output.s3_parquet.catalog.GlueDataType
import org.embulk.spi.`type`.{
  BooleanType,
  DoubleType,
  JsonType,
  LongType,
  StringType,
  TimestampType
}
import org.embulk.spi.time.{Timestamp, TimestampFormatter}
import org.embulk.spi.Column
import org.msgpack.value.Value
import org.slf4j.{Logger, LoggerFactory}

case class TimestampLogicalType(
    isAdjustedToUtc: Boolean,
    timeUnit: TimeUnit,
    timeZone: ZoneId
) extends ParquetColumnType {
  private val logger: Logger =
    LoggerFactory.getLogger(classOf[TimestampLogicalType])

  override def primitiveType(column: Column): PrimitiveType =
    column.getType match {
      case _: LongType | _: TimestampType =>
        Types
          .optional(PrimitiveTypeName.INT64)
          .as(LogicalTypeAnnotation.timestampType(isAdjustedToUtc, timeUnit))
          .named(column.getName)
      case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }

  override def glueDataType(column: Column): GlueDataType =
    column.getType match {
      case _: LongType | _: TimestampType =>
        timeUnit match {
          case MILLIS => GlueDataType.TIMESTAMP
          case MICROS | NANOS =>
            warningWhenConvertingTimestampToGlueType(GlueDataType.BIGINT)
            GlueDataType.BIGINT
        }
      case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }

  override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit =
    throw newUnsupportedMethodException("consumeBoolean")
  override def consumeString(consumer: RecordConsumer, v: String): Unit =
    throw newUnsupportedMethodException("consumeString")

  override def consumeLong(consumer: RecordConsumer, v: Long): Unit =
    consumer.addLong(v)

  override def consumeDouble(consumer: RecordConsumer, v: Double): Unit =
    throw newUnsupportedMethodException("consumeDouble")

  override def consumeTimestamp(
      consumer: RecordConsumer,
      v: Timestamp,
      formatter: TimestampFormatter
  ): Unit = timeUnit match {
    case MILLIS => consumer.addLong(v.toEpochMilli)
    case MICROS =>
      consumer.addLong(v.getEpochSecond * 1_000_000L + (v.getNano / 1_000L))
    case NANOS =>
      consumer.addLong(v.getEpochSecond * 1_000_000_000L + v.getNano)
  }

  override def consumeJson(consumer: RecordConsumer, v: Value): Unit =
    throw newUnsupportedMethodException("consumeJson")

  private def warningWhenConvertingTimestampToGlueType(
      glueType: GlueDataType
  ): Unit =
    logger.warn(
      s"timestamp(isAdjustedToUtc = $isAdjustedToUtc, timeUnit = $timeUnit) is converted" +
        s" to Glue ${glueType.name} but this is not represented correctly, because Glue" +
        s" does not support time type. Please use `catalog.column_options` to define the type."
    )
} 
Example 2
Source File: JsonLogicalType.scala    From embulk-output-s3_parquet   with MIT License
package org.embulk.output.s3_parquet.parquet
import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.embulk.config.ConfigException
import org.embulk.output.s3_parquet.catalog.GlueDataType
import org.embulk.spi.Column
import org.embulk.spi.`type`.{
  BooleanType,
  DoubleType,
  JsonType,
  LongType,
  StringType,
  TimestampType
}
import org.embulk.spi.time.{Timestamp, TimestampFormatter}
import org.msgpack.value.{Value, ValueFactory}
import org.slf4j.{Logger, LoggerFactory}

object JsonLogicalType extends ParquetColumnType {
  private val logger: Logger = LoggerFactory.getLogger(JsonLogicalType.getClass)
  override def primitiveType(column: Column): PrimitiveType =
    column.getType match {
      case _: BooleanType | _: LongType | _: DoubleType | _: StringType |
          _: JsonType =>
        Types
          .optional(PrimitiveTypeName.BINARY)
          .as(LogicalTypeAnnotation.jsonType())
          .named(column.getName)
      case _: TimestampType | _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }

  override def glueDataType(column: Column): GlueDataType =
    column.getType match {
      case _: BooleanType | _: LongType | _: DoubleType | _: StringType |
          _: JsonType =>
        warningWhenConvertingJsonToGlueType(GlueDataType.STRING)
        GlueDataType.STRING
      case _: TimestampType | _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }

  override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit =
    consumeJson(consumer, ValueFactory.newBoolean(v))

  override def consumeString(consumer: RecordConsumer, v: String): Unit =
    consumeJson(consumer, ValueFactory.newString(v))

  override def consumeLong(consumer: RecordConsumer, v: Long): Unit =
    consumeJson(consumer, ValueFactory.newInteger(v))

  override def consumeDouble(consumer: RecordConsumer, v: Double): Unit =
    consumeJson(consumer, ValueFactory.newFloat(v))

  override def consumeTimestamp(
      consumer: RecordConsumer,
      v: Timestamp,
      formatter: TimestampFormatter
  ): Unit = throw newUnsupportedMethodException("consumeTimestamp")

  override def consumeJson(consumer: RecordConsumer, v: Value): Unit =
    consumer.addBinary(Binary.fromString(v.toJson))

  private def warningWhenConvertingJsonToGlueType(
      glueType: GlueDataType
  ): Unit = {
    logger.warn(
      s"json is converted" +
        s" to Glue ${glueType.name} but this is not represented correctly, because Glue" +
        s" does not support json type. Please use `catalog.column_options` to define the type."
    )
  }

} 
Example 3
Source File: LogicalTypeProxy.scala    From embulk-output-s3_parquet   with MIT License
package org.embulk.output.s3_parquet.parquet

import java.time.ZoneId
import java.util.Locale

import org.apache.parquet.io.api.RecordConsumer
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.MILLIS
import org.apache.parquet.schema.PrimitiveType
import org.embulk.config.ConfigException
import org.embulk.output.s3_parquet.catalog.GlueDataType
import org.embulk.spi.Column
import org.embulk.spi.time.{Timestamp, TimestampFormatter}
import org.msgpack.value.Value

object LogicalTypeProxy {
  private val DEFAULT_SCALE: Int = 0
  private val DEFAULT_BID_WIDTH: Int = 64
  private val DEFAULT_IS_SIGNED: Boolean = true
  private val DEFAULT_IS_ADJUSTED_TO_UTC: Boolean = true
  private val DEFAULT_TIME_UNIT: TimeUnit = MILLIS
  private val DEFAULT_TIME_ZONE: ZoneId = ZoneId.of("UTC")
}

case class LogicalTypeProxy(
    name: String,
    scale: Option[Int] = None,
    precision: Option[Int] = None,
    bitWidth: Option[Int] = None,
    isSigned: Option[Boolean] = None,
    isAdjustedToUtc: Option[Boolean] = None,
    timeUnit: Option[TimeUnit] = None,
    timeZone: Option[ZoneId] = None
) extends ParquetColumnType {
  private def getScale: Int = scale.getOrElse(LogicalTypeProxy.DEFAULT_SCALE)
  private def getPrecision: Int = precision.getOrElse {
    throw new ConfigException("\"precision\" must be set.")
  }
  private def getBidWith: Int =
    bitWidth.getOrElse(LogicalTypeProxy.DEFAULT_BID_WIDTH)
  private def getIsSigned: Boolean =
    isSigned.getOrElse(LogicalTypeProxy.DEFAULT_IS_SIGNED)
  private def getIsAdjustedToUtc: Boolean =
    isAdjustedToUtc.getOrElse(LogicalTypeProxy.DEFAULT_IS_ADJUSTED_TO_UTC)
  private def getTimeUnit: TimeUnit =
    timeUnit.getOrElse(LogicalTypeProxy.DEFAULT_TIME_UNIT)
  private def getTimeZone: ZoneId =
    timeZone.getOrElse(LogicalTypeProxy.DEFAULT_TIME_ZONE)

  lazy val logicalType: ParquetColumnType = {
    name.toUpperCase(Locale.ENGLISH) match {
      case "INT" => IntLogicalType(getBidWith, getIsSigned)
      case "TIMESTAMP" =>
        TimestampLogicalType(getIsAdjustedToUtc, getTimeUnit, getTimeZone)
      case "TIME" =>
        TimeLogicalType(getIsAdjustedToUtc, getTimeUnit, getTimeZone)
      case "DECIMAL" => DecimalLogicalType(getScale, getPrecision)
      case "DATE"    => DateLogicalType
      case "JSON"    => JsonLogicalType
      case _ =>
        throw new ConfigException(s"Unsupported logical_type.name: $name.")
    }
  }

  override def primitiveType(column: Column): PrimitiveType =
    logicalType.primitiveType(column)
  override def glueDataType(column: Column): GlueDataType =
    logicalType.glueDataType(column)
  override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit =
    logicalType.consumeBoolean(consumer, v)
  override def consumeString(consumer: RecordConsumer, v: String): Unit =
    logicalType.consumeString(consumer, v)
  override def consumeLong(consumer: RecordConsumer, v: Long): Unit =
    logicalType.consumeLong(consumer, v)
  override def consumeDouble(consumer: RecordConsumer, v: Double): Unit =
    logicalType.consumeDouble(consumer, v)
  override def consumeTimestamp(
      consumer: RecordConsumer,
      v: Timestamp,
      formatter: TimestampFormatter
  ): Unit = logicalType.consumeTimestamp(consumer, v, formatter)
  override def consumeJson(consumer: RecordConsumer, v: Value): Unit =
    logicalType.consumeJson(consumer, v)
} 
Example 4
Source File: DateLogicalType.scala    From embulk-output-s3_parquet   with MIT License
package org.embulk.output.s3_parquet.parquet

import java.time.{Duration, Instant}

import org.apache.parquet.io.api.RecordConsumer
import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.embulk.config.ConfigException
import org.embulk.output.s3_parquet.catalog.GlueDataType
import org.embulk.spi.`type`.{
  BooleanType,
  DoubleType,
  JsonType,
  LongType,
  StringType,
  TimestampType
}
import org.embulk.spi.time.{Timestamp, TimestampFormatter}
import org.embulk.spi.Column
import org.msgpack.value.Value

object DateLogicalType extends ParquetColumnType {
  override def primitiveType(column: Column): PrimitiveType = {
    column.getType match {
      case _: LongType | _: TimestampType =>
        Types
          .optional(PrimitiveTypeName.INT32)
          .as(LogicalTypeAnnotation.dateType())
          .named(column.getName)
      case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }
  }

  override def glueDataType(column: Column): GlueDataType =
    column.getType match {
      case _: LongType | _: TimestampType => GlueDataType.DATE
      case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }

  override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit =
    throw newUnsupportedMethodException("consumeBoolean")

  override def consumeString(consumer: RecordConsumer, v: String): Unit =
    throw newUnsupportedMethodException("consumeString")

  override def consumeLong(consumer: RecordConsumer, v: Long): Unit =
    consumeLongAsInteger(consumer, v)

  override def consumeDouble(consumer: RecordConsumer, v: Double): Unit =
    throw newUnsupportedMethodException("consumeDouble")

  override def consumeTimestamp(
      consumer: RecordConsumer,
      v: Timestamp,
      formatter: TimestampFormatter
  ): Unit =
    consumeLongAsInteger(
      consumer,
      Duration.between(Instant.EPOCH, v.getInstant).toDays
    )

  override def consumeJson(consumer: RecordConsumer, v: Value): Unit =
    throw newUnsupportedMethodException("consumeJson")
} 
Example 5
Source File: DefaultColumnType.scala    From embulk-output-s3_parquet   with MIT License
package org.embulk.output.s3_parquet.parquet

import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.embulk.config.ConfigException
import org.embulk.output.s3_parquet.catalog.GlueDataType
import org.embulk.spi.time.{Timestamp, TimestampFormatter}
import org.embulk.spi.Column
import org.embulk.spi.`type`.{
  BooleanType,
  DoubleType,
  JsonType,
  LongType,
  StringType,
  TimestampType
}
import org.msgpack.value.Value

object DefaultColumnType extends ParquetColumnType {
  override def primitiveType(column: Column): PrimitiveType =
    column.getType match {
      case _: BooleanType =>
        Types.optional(PrimitiveTypeName.BOOLEAN).named(column.getName)
      case _: LongType =>
        Types.optional(PrimitiveTypeName.INT64).named(column.getName)
      case _: DoubleType =>
        Types.optional(PrimitiveTypeName.DOUBLE).named(column.getName)
      case _: StringType =>
        Types
          .optional(PrimitiveTypeName.BINARY)
          .as(LogicalTypeAnnotation.stringType())
          .named(column.getName)
      case _: TimestampType =>
        Types
          .optional(PrimitiveTypeName.BINARY)
          .as(LogicalTypeAnnotation.stringType())
          .named(column.getName)
      case _: JsonType =>
        Types
          .optional(PrimitiveTypeName.BINARY)
          .as(LogicalTypeAnnotation.stringType())
          .named(column.getName)
      case _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }

  override def glueDataType(column: Column): GlueDataType =
    column.getType match {
      case _: BooleanType =>
        GlueDataType.BOOLEAN
      case _: LongType =>
        GlueDataType.BIGINT
      case _: DoubleType =>
        GlueDataType.DOUBLE
      case _: StringType | _: TimestampType | _: JsonType =>
        GlueDataType.STRING
      case _ =>
        throw new ConfigException(s"Unsupported column type: ${column.getName}")
    }

  override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit =
    consumer.addBoolean(v)
  override def consumeString(consumer: RecordConsumer, v: String): Unit =
    consumer.addBinary(Binary.fromString(v))
  override def consumeLong(consumer: RecordConsumer, v: Long): Unit =
    consumer.addLong(v)
  override def consumeDouble(consumer: RecordConsumer, v: Double): Unit =
    consumer.addDouble(v)
  override def consumeTimestamp(
      consumer: RecordConsumer,
      v: Timestamp,
      formatter: TimestampFormatter
  ): Unit = consumer.addBinary(Binary.fromString(formatter.format(v)))
  override def consumeJson(consumer: RecordConsumer, v: Value): Unit =
    consumer.addBinary(Binary.fromString(v.toJson))
} 
Example 6
Source File: MockParquetRecordConsumer.scala    From embulk-output-s3_parquet   with MIT License
package org.embulk.output.s3_parquet.parquet

import org.apache.parquet.io.api.{Binary, RecordConsumer}

case class MockParquetRecordConsumer() extends RecordConsumer {
  case class Data private (messages: Seq[Message] = Seq()) {
    def toData: Seq[Seq[Any]] = messages.map(_.toData)
  }
  case class Message private (fields: Seq[Field] = Seq()) {
    def toData: Seq[Any] = {
      val maxIndex: Int = fields.maxBy(_.index).index
      val raw: Map[Int, Any] = fields.map(f => f.index -> f.value).toMap
      0.to(maxIndex).map(idx => raw.get(idx).orNull)
    }
  }
  case class Field private (index: Int = 0, value: Any = null)

  private var _data: Data = Data()
  private var _message: Message = Message()
  private var _field: Field = Field()

  override def startMessage(): Unit = _message = Message()
  override def endMessage(): Unit =
    _data = _data.copy(messages = _data.messages :+ _message)
  override def startField(field: String, index: Int): Unit =
    _field = Field(index = index)
  override def endField(field: String, index: Int): Unit =
    _message = _message.copy(fields = _message.fields :+ _field)
  override def startGroup(): Unit = throw new UnsupportedOperationException
  override def endGroup(): Unit = throw new UnsupportedOperationException
  override def addInteger(value: Int): Unit =
    _field = _field.copy(value = value)
  override def addLong(value: Long): Unit = _field = _field.copy(value = value)
  override def addBoolean(value: Boolean): Unit =
    _field = _field.copy(value = value)
  override def addBinary(value: Binary): Unit =
    _field = _field.copy(value = value)
  override def addFloat(value: Float): Unit =
    _field = _field.copy(value = value)
  override def addDouble(value: Double): Unit =
    _field = _field.copy(value = value)

  def writingMessage(f: => Unit): Unit = {
    startMessage()
    f
    endMessage()
  }
  def writingField(field: String, index: Int)(f: => Unit): Unit = {
    startField(field, index)
    f
    endField(field, index)
  }
  def writingSampleField(f: => Unit): Unit = {
    writingMessage {
      writingField("a", 0)(f)
    }
  }
  def data: Seq[Seq[Any]] = _data.toData
} 
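A short usage sketch (not part of the original test sources) of the mock above: each writingSampleField call records one message containing a single field "a" at index 0, and data returns the captured values, one inner sequence per message.

val consumer = MockParquetRecordConsumer()
consumer.writingSampleField { consumer.addLong(1L) }
consumer.writingSampleField { consumer.addLong(2L) }
// Two messages were captured, each with one value at field index 0.
assert(consumer.data == Seq(Seq(1L), Seq(2L)))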
Example 7
Source File: RowWriteSupport.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.parquet

import com.sksamuel.exts.Logging
import io.eels.Row
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.FinalizedWriteContext
import org.apache.parquet.io.api.RecordConsumer
import org.apache.parquet.schema.MessageType

import scala.collection.JavaConverters._
import scala.math.BigDecimal.RoundingMode.RoundingMode

// Implementation of WriteSupport for Rows, used by the native ParquetWriter
class RowWriteSupport(schema: MessageType,
                      roundingMode: RoundingMode,
                      metadata: Map[String, String]) extends WriteSupport[Row] with Logging {
  logger.trace(s"Created parquet row write support for schema message type $schema")

  private var writer: RowWriter = _

  override def finalizeWrite(): FinalizedWriteContext = new FinalizedWriteContext(metadata.asJava)

  def init(configuration: Configuration): WriteSupport.WriteContext = {
    new WriteSupport.WriteContext(schema, new java.util.HashMap())
  }

  def prepareForWrite(record: RecordConsumer): Unit = {
    writer = new RowWriter(record, roundingMode)
  }

  def write(row: Row): Unit = {
    writer.write(row)
  }
}

class RowWriter(record: RecordConsumer, roundingMode: RoundingMode) {

  def write(row: Row): Unit = {
    record.startMessage()
    val writer = new StructRecordWriter(row.schema, roundingMode, false)
    writer.write(record, row.values)
    record.endMessage()
  }
} 
Example 8
Source File: DirectParquetWriter.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters._

import org.apache.hadoop.conf
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetWriter
import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
import org.apache.parquet.io.api.RecordConsumer
import org.apache.parquet.schema.{MessageType, MessageTypeParser}

object DirectParquetWriter {

  // A RecordBuilder is a function that writes one record through a RecordConsumer.
  type RecordBuilder = RecordConsumer => Unit

  private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String])
    extends WriteSupport[RecordBuilder] {

    private var recordConsumer: RecordConsumer = _

    // Initialization: build the write context from the schema and file metadata.
    override def init(configuration: conf.Configuration): WriteContext = {
      new WriteContext(schema, metadata.asJava)
    }
    // Write one record by applying the builder function to the RecordConsumer.
    override def write(buildRecord: RecordBuilder): Unit = {
      recordConsumer.startMessage()
      buildRecord(recordConsumer)
      recordConsumer.endMessage()
    }
    // Prepare for writing: capture the RecordConsumer supplied by the Parquet framework.
    override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
      this.recordConsumer = recordConsumer
    }
  }
  // Write a Parquet file at the given path by handing a ParquetWriter to the caller.
  def writeDirect
      (path: String, schema: String, metadata: Map[String, String] = Map.empty)
      (f: ParquetWriter[RecordBuilder] => Unit): Unit = {
    val messageType = MessageTypeParser.parseMessageType(schema)
    val writeSupport = new DirectWriteSupport(messageType, metadata)
    val parquetWriter = new ParquetWriter[RecordBuilder](new Path(path), writeSupport)
    try f(parquetWriter) finally parquetWriter.close()
  }
  // Write a single message (record) through the writer.
  def message(writer: ParquetWriter[RecordBuilder])(builder: RecordBuilder): Unit = {
    writer.write(builder)
  }
  // Bracket the given writes with startGroup/endGroup.
  def group(consumer: RecordConsumer)(f: => Unit): Unit = {
    consumer.startGroup()
    f
    consumer.endGroup()
  }
  // Bracket the given writes with startField/endField for the named column.
  def field(consumer: RecordConsumer, name: String, index: Int = 0)(f: => Unit): Unit = {
    consumer.startField(name, index)
    f
    consumer.endField(name, index)
  }
} 
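A hypothetical use of writeDirect together with the message and field helpers above; the output path and the schema string are placeholders, and Binary must be imported in addition to the imports shown.

import org.apache.parquet.io.api.Binary

val schema =
  """message example {
    |  required int32 id;
    |  required binary name (UTF8);
    |}""".stripMargin

// Writes a single two-column record to a placeholder path.
writeDirect("/tmp/direct-example.parquet", schema) { writer =>
  message(writer) { consumer =>
    field(consumer, "id") {
      consumer.addInteger(1)
    }
    field(consumer, "name", index = 1) {
      consumer.addBinary(Binary.fromString("a"))
    }
  }
}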
Example 9
Source File: ParquetCompatibilityTest.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter}
import org.apache.parquet.io.api.RecordConsumer
import org.apache.parquet.schema.{MessageType, MessageTypeParser}

import org.apache.spark.sql.QueryTest


// Companion helpers for writing Parquet files directly from RecordConsumer callbacks.
// DirectWriteSupport here is the same kind of helper shown in Example 8 and is assumed
// to be defined in the surrounding object.
object ParquetCompatibilityTest {

  def writeDirect(
      path: String,
      schema: String,
      metadata: Map[String, String],
      recordWriters: (RecordConsumer => Unit)*): Unit = {
    val messageType = MessageTypeParser.parseMessageType(schema)
    val writeSupport = new DirectWriteSupport(messageType, metadata)
    val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport)
    try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close()
  }
} 
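A hypothetical call of the writeDirect above (placeholder path and inline schema, with the imports from the example in scope): each trailing callback writes one record through the RecordConsumer.

writeDirect(
  "/tmp/compat-example.parquet",       // placeholder output path
  "message m { required int64 id; }",  // inline Parquet schema
  Map.empty[String, String],
  { consumer: RecordConsumer =>
    consumer.startMessage()
    consumer.startField("id", 0)
    consumer.addLong(1L)
    consumer.endField("id", 0)
    consumer.endMessage()
  }
  // ...additional callbacks would append additional records
)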
Example 10
Source File: StructWriteSupport.scala    From stream-reactor   with Apache License 2.0
package com.landoop.streamreactor.connect.hive.parquet

import com.landoop.streamreactor.connect.hive._
import org.apache.hadoop.conf.Configuration
import org.apache.kafka.connect.data.{Schema, Struct}
import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.FinalizedWriteContext
import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.apache.parquet.schema.MessageType

import scala.collection.JavaConverters._

// derived from Apache Spark's parquet write support, archive and license here:
// https://github.com/apache/spark/blob/21a7bfd5c324e6c82152229f1394f26afeae771c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
class StructWriteSupport(schema: Schema) extends WriteSupport[Struct] {

  private val logger = org.slf4j.LoggerFactory.getLogger(getClass.getName)
  private val schemaName = if (schema.name() == null) "schema" else schema.name()
  private val parquetSchema: MessageType = ParquetSchemas.toParquetMessage(schema, schemaName)

  private val metadata = new java.util.HashMap[String, String]()
  metadata.put("written_by", "streamreactor")

  // The Parquet `RecordConsumer` to which all structs are written
  private var consumer: RecordConsumer = _

  type ValueWriter = (Any) => Unit

  override def init(conf: Configuration): WriteSupport.WriteContext = new WriteSupport.WriteContext(parquetSchema, new java.util.HashMap[String, String])
  override def finalizeWrite(): WriteSupport.FinalizedWriteContext = new FinalizedWriteContext(metadata)
  override def prepareForWrite(consumer: RecordConsumer): Unit = this.consumer = consumer

  override def write(struct: Struct): Unit = {
    writeMessage {
      writeStructFields(struct)
    }
  }

  private def writeStructFields(struct: Struct): Unit = {
    for ((field, index) <- struct.schema.fields.asScala.zipWithIndex) {
      val value = struct.get(field)
      if (value != null) {
        val writer = valueWriter(field.schema())
        writeField(field.name, index) {
          writer(value)
        }
      }
    }
  }

  def valueWriter(schema: Schema): ValueWriter = {
    // todo perhaps introduce something like spark's SpecializedGetters
    schema.`type`() match {
      case Schema.Type.BOOLEAN => value => consumer.addBoolean(value.asInstanceOf[Boolean])
      case Schema.Type.INT8 | Schema.Type.INT16 | Schema.Type.INT32 => value => consumer.addInteger(value.toString.toInt)
      case Schema.Type.INT64 => value => consumer.addLong(value.toString.toLong)
      case Schema.Type.STRING => value => consumer.addBinary(Binary.fromReusedByteArray(value.toString.getBytes))
      case Schema.Type.FLOAT32 => value => consumer.addFloat(value.toString.toFloat)
      case Schema.Type.FLOAT64 => value => consumer.addDouble(value.toString.toDouble)
      case Schema.Type.STRUCT => value => {
        logger.debug(s"Writing nested struct")
        val struct = value.asInstanceOf[Struct]
        writeGroup {
          schema.fields.asScala
            .map { field => field -> struct.get(field) }
            .zipWithIndex.foreach { case ((field, v), k) =>
            writeField(field.name, k) {
              valueWriter(field.schema)(v)
            }
          }
        }
      }
      case _ => throw UnsupportedSchemaType(schema.`type`.toString)
    }
  }

  private def writeMessage(f: => Unit): Unit = {
    consumer.startMessage()
    f
    consumer.endMessage()
  }

  private def writeGroup(f: => Unit): Unit = {
    consumer.startGroup()
    // consumer.startMessage()
    f
    //consumer.endMessage()
    consumer.endGroup()
  }

  private def writeField(name: String, k: Int)(f: => Unit): Unit = {
    consumer.startField(name, k)
    f
    consumer.endField(name, k)
  }
} 