org.apache.spark.sql.sources.TableScan Scala Examples

The following examples show how to use org.apache.spark.sql.sources.TableScan. Each example is drawn from an open-source project; the heading above each one names the source file, the originating project, and its license.
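
As a primer, here is a minimal sketch of the contract every example below satisfies: BaseRelation supplies the schema and the SQLContext, and TableScan adds a buildScan() that returns the whole table as an RDD[Row]. The relation name and table contents are hypothetical.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// A hypothetical relation serving a fixed, in-memory table.
class NumbersRelation(override val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  // BaseRelation contract: describe the table's columns
  override def schema: StructType = StructType(Seq(
    StructField("n", IntegerType, nullable = false),
    StructField("name", StringType, nullable = true)))

  // TableScan contract: produce every row of the table as an RDD[Row]
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(Seq(Row(1, "one"), Row(2, "two")))
}
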
Example 1
Source File: PointCloudRelation.scala    From geotrellis-pointcloud with Apache License 2.0
package geotrellis.pointcloud.spark.datasource

import geotrellis.pointcloud.spark.store.hadoop._
import geotrellis.pointcloud.spark.store.hadoop.HadoopPointCloudRDD.{Options => HadoopOptions}
import geotrellis.pointcloud.util.Filesystem
import geotrellis.proj4.CRS
import geotrellis.store.hadoop.util.HdfsUtils
import geotrellis.vector.Extent

import cats.implicits._
import io.pdal._
import io.circe.syntax._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}

import java.io.File

import scala.collection.JavaConverters._

// This class has to be serializable since it is shipped over the network.
class PointCloudRelation(
  val sqlContext: SQLContext,
  path: String,
  options: HadoopOptions
) extends BaseRelation with TableScan with Serializable {

  @transient implicit lazy val sc: SparkContext = sqlContext.sparkContext

  // TODO: switch between HadoopPointCloudRDD and S3PointCloudRDD depending on the path scheme
  lazy val isS3: Boolean = path.startsWith("s3")

  override def schema: StructType = {
    lazy val (local, fixedPath) =
      if(path.startsWith("s3") || path.startsWith("hdfs")) {
        val tmpDir = Filesystem.createDirectory()
        val remotePath = new Path(path)
        // copy remote file into local tmp dir
        val localPath = new File(tmpDir, remotePath.getName)
        HdfsUtils.copyPath(remotePath, new Path(s"file:///${localPath.getAbsolutePath}"), sc.hadoopConfiguration)
        (true, localPath.toString)
      } else (false, path)

    val localPipeline =
      options.pipeline
        .hcursor
        .downField("pipeline").downArray
        .downField("filename").withFocus(_ => fixedPath.asJson)
        .top.fold(options.pipeline)(identity)

    val pl = Pipeline(localPipeline.noSpaces)
    if (pl.validate()) pl.execute()
    val pointCloud = try {
      pl.getPointViews().next().getPointCloud(0)
    } finally {
      pl.close()
      if (local) new File(fixedPath).delete() // clean up the local copy
    }

    val rdd = HadoopPointCloudRDD(new Path(path), options)

    val md: (Option[Extent], Option[CRS]) =
      rdd
        .map { case (header, _) => (header.projectedExtent3D.map(_.extent3d.toExtent), header.crs) }
        .reduce { case ((e1, c), (e2, _)) => ((e1, e2).mapN(_ combine _), c) }

    val metadata = new MetadataBuilder().putString("metadata", md.asJson.noSpaces).build

    pointCloud.deriveSchema(metadata)
  }

  override def buildScan(): RDD[Row] = {
    val rdd = HadoopPointCloudRDD(new Path(path), options)
    rdd.flatMap { _._2.flatMap { pc => pc.readAll.toList.map { k => Row(k: _*) } } }
  }
} 
Example 2
Source File: XmlRelation.scala    From spark-xml with Apache License 2.0
package com.databricks.spark.xml

import java.io.IOException

import org.apache.hadoop.fs.Path

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.sources.{PrunedScan, InsertableRelation, BaseRelation, TableScan}
import org.apache.spark.sql.types._
import com.databricks.spark.xml.util.{InferSchema, XmlFile}
import com.databricks.spark.xml.parsers.StaxXmlParser

case class XmlRelation protected[spark] (
    baseRDD: () => RDD[String],
    location: Option[String],
    parameters: Map[String, String],
    userSchema: StructType = null)(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with InsertableRelation
  with PrunedScan {

  private val options = XmlOptions(parameters)

  override val schema: StructType = {
    Option(userSchema).getOrElse {
      InferSchema.infer(
        baseRDD(),
        options)
    }
  }

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val requiredFields = requiredColumns.map(schema(_))
    val requestedSchema = StructType(requiredFields)
    StaxXmlParser.parse(
      baseRDD(),
      requestedSchema,
      options)
  }

  // The function below was borrowed from JSONRelation
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val filesystemPath = location match {
      case Some(p) => new Path(p)
      case None =>
        throw new IOException("Cannot INSERT into table with no path defined")
    }

    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)

    if (overwrite) {
      try {
        fs.delete(filesystemPath, true)
      } catch {
        case e: IOException =>
          throw new IOException(
            s"Unable to clear output directory ${filesystemPath.toString} prior"
              + s" to INSERT OVERWRITE an XML table:\n${e.toString}")
      }
      // Write the data. We assume that schema isn't changed, and we won't update it.
      XmlFile.saveAsXmlFile(data, filesystemPath.toString, parameters)
    } else {
      throw new IllegalArgumentException("XML tables only support INSERT OVERWRITE for now.")
    }
  }
} 
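
Relations like this one are normally constructed by the package's RelationProvider rather than instantiated directly; for spark-xml that means going through the DataFrame reader. A brief sketch (rowTag names the XML element treated as one row; the path is a placeholder):

val df = sqlContext.read
  .format("com.databricks.spark.xml")
  .option("rowTag", "book")
  .load("books.xml")
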
Example 3
Source File: ExcelRelation.scala    From spark-hadoopoffice-ds with Apache License 2.0
package org.zuinnote.spark.office.excel

import scala.collection.JavaConversions._

import org.apache.spark.sql.sources.{ BaseRelation, TableScan }
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.SQLContext

import org.apache.spark.sql._
import org.apache.spark.rdd.RDD

import org.apache.hadoop.conf._
import org.apache.hadoop.mapreduce._

import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.Log

import org.zuinnote.hadoop.office.format.common.dao._
import org.zuinnote.hadoop.office.format.mapreduce._

import org.zuinnote.spark.office.excel.util.ExcelFile


// NOTE: the original excerpt elides the surrounding class declaration and schema
// definition. The reconstruction below is a sketch inferred from the identifiers
// used in buildScan (location, hadoopParams, sqlContext) and from the five cell
// attributes populated per SpreadSheetCellDAO; names and types are best guesses.
case class ExcelRelation(
    location: String,
    hadoopParams: Configuration)
    (@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with Serializable {

  // sketched schema: each row carries an array of cell structs,
  // one struct per SpreadSheetCellDAO with five string attributes
  override def schema: StructType = StructType(Seq(
    StructField("rows", ArrayType(StructType(Seq(
      StructField("formattedValue", StringType, nullable = true),
      StructField("comment", StringType, nullable = true),
      StructField("formula", StringType, nullable = true),
      StructField("address", StringType, nullable = true),
      StructField("sheetName", StringType, nullable = true)
    ))), nullable = true)))

  override def buildScan(): RDD[Row] = {
    // read Excel rows
    val excelRowsRDD = ExcelFile.load(sqlContext, location, hadoopParams)
    // map the Excel row data structure to the Spark SQL schema
    excelRowsRDD.flatMap(excelKeyValueTuple => {
      val cells = excelKeyValueTuple._2.get.asInstanceOf[Array[SpreadSheetCellDAO]]
      val rowArray = new Array[Any](cells.length)
      var i = 0
      for (cell <- cells) { // parse through the SpreadSheetCellDAO entries
        // five attributes per cell: formattedValue, comment, formula, address, sheetName
        val cellStructArray = new Array[String](5)
        cellStructArray(0) = cell.getFormattedValue
        cellStructArray(1) = cell.getComment
        cellStructArray(2) = cell.getFormula
        cellStructArray(3) = cell.getAddress
        cellStructArray(4) = cell.getSheetName
        // add the struct representing one cell of the Excel row
        rowArray(i) = cellStructArray
        i += 1
      }
      Some(Row.fromSeq(rowArray))
    })
  }

} 
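
As with the other sources here, the relation is reached through the reader API. A hedged sketch, assuming the data source is addressed by its package name as is conventional for Spark DataSource V1 packages (the path is a placeholder):

val df = sqlContext.read
  .format("org.zuinnote.spark.office.excel")
  .load("/path/to/file.xlsx")
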
Example 4
Source File: EventHubsRelation.scala    From azure-event-hubs-spark with Apache License 2.0
package org.apache.spark.sql.eventhubs

import org.apache.spark.eventhubs.rdd.{ EventHubsRDD, OffsetRange }
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ Row, SQLContext }
import org.apache.spark.sql.sources.{ BaseRelation, TableScan }
import org.apache.spark.sql.types.StructType

import scala.language.postfixOps


private[eventhubs] class EventHubsRelation(override val sqlContext: SQLContext,
                                           parameters: Map[String, String])
    extends BaseRelation
    with TableScan
    with Logging {

  import org.apache.spark.eventhubs._

  private val ehConf = EventHubsConf.toConf(parameters)
  private val eventHubClient = EventHubsSourceProvider.clientFactory(parameters)(ehConf)

  override def schema: StructType = EventHubsSourceProvider.eventHubsSchema

  override def buildScan(): RDD[Row] = {
    val partitionCount: Int = eventHubClient.partitionCount

    val fromSeqNos = eventHubClient.translate(ehConf, partitionCount)
    val untilSeqNos = eventHubClient.translate(ehConf, partitionCount, useStart = false)

    require(fromSeqNos.forall(f => f._2 >= 0L),
            "Currently only sequence numbers can be passed in your starting positions.")
    require(untilSeqNos.forall(u => u._2 >= 0L),
            "Currently only sequence numbers can be passed in your ending positions.")

    val offsetRanges = untilSeqNos.keySet.map { p =>
      val fromSeqNo = fromSeqNos
        .getOrElse(p, throw new IllegalStateException(s"$p doesn't have a fromSeqNo"))
      val untilSeqNo = untilSeqNos(p)
      OffsetRange(ehConf.name, p, fromSeqNo, untilSeqNo, None)
    }.toArray
    eventHubClient.close()

    logInfo(
      "GetBatch generating RDD with offsetRanges: " +
        offsetRanges.sortBy(_.nameAndPartition.toString).mkString(", "))

    val rdd = EventHubsSourceProvider.toInternalRow(
      new EventHubsRDD(sqlContext.sparkContext, ehConf.trimmed, offsetRanges))
    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = false).rdd
  }
} 
Example 5
Source File: DruidRelation.scala    From spark-druid-olap with Apache License 2.0
package org.sparklinedata.druid

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
import org.joda.time.Interval
import org.sparklinedata.druid.metadata.DruidRelationInfo


case class DruidOperatorAttribute(exprId : ExprId, name : String, dataType : DataType,
                                  tf: String = null)



// NOTE: the original excerpt elides the DruidRelation declaration; a minimal
// reconstruction is sketched here from the members referenced in the body below
// (info, dQuery) and the project's types (DruidRelationInfo, DruidQuery).
case class DruidRelation(info: DruidRelationInfo, dQuery: Option[DruidQuery])
                        (@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override val needConversion: Boolean = false

  override def schema: StructType =
    dQuery.map(_.schema(info)).getOrElse(info.sourceDF(sqlContext).schema)

  def buildInternalScan : RDD[InternalRow] =
    dQuery.map(new DruidRDD(sqlContext, info, _)).getOrElse(
      info.sourceDF(sqlContext).queryExecution.toRdd
    )

  override def buildScan(): RDD[Row] =
    buildInternalScan.asInstanceOf[RDD[Row]]

  override def toString : String = {
    if (dQuery.isDefined) {
      s"DruidQuery(${System.identityHashCode(dQuery)}): ${Utils.queryToString(dQuery.get)}"
    } else {
      info.toString
    }
  }
} 
Example 6
Source File: DatasetRelation.scala    From spark-sftp with Apache License 2.0
package com.springml.spark.sftp

import com.databricks.spark.avro._
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType


case class DatasetRelation(
    fileLocation: String,
    fileType: String,
    inferSchema: String,
    header: String,
    delimiter: String,
    quote: String,
    escape: String,
    multiLine: String,
    rowTag: String,
    customSchema: StructType,
    sqlContext: SQLContext) extends BaseRelation with TableScan {

    private val logger = Logger.getLogger(classOf[DatasetRelation])

    val df = read()

    private def read(): DataFrame = {
      var dataframeReader = sqlContext.read
      if (customSchema != null) {
        dataframeReader = dataframeReader.schema(customSchema)
      }

      val df: DataFrame = fileType match {
        case "avro" => dataframeReader.avro(fileLocation)
        case "txt" => dataframeReader.format("text").load(fileLocation)
        case "xml" => dataframeReader.format(constants.xmlClass)
          .option(constants.xmlRowTag, rowTag)
          .load(fileLocation)
        case "csv" => dataframeReader
          .option("header", header)
          .option("delimiter", delimiter)
          .option("quote", quote)
          .option("escape", escape)
          .option("multiLine", multiLine)
          .option("inferSchema", inferSchema)
          .csv(fileLocation)
        case _ => dataframeReader.format(fileType).load(fileLocation)
      }
      df
    }

    override def schema: StructType = {
      df.schema
    }

    override def buildScan(): RDD[Row] = {
      df.rdd
    }

} 
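
DatasetRelation's string-typed fields mirror reader options one-for-one. A hedged usage sketch, assuming the connector is registered under its package name and accepts the connection options its documentation lists (host, username, password); all values here are placeholders:

val df = sqlContext.read
  .format("com.springml.spark.sftp")
  .option("host", "sftp.example.com")
  .option("username", "user")
  .option("password", "secret")
  .option("fileType", "csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("/remote/path/data.csv")
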
Example 7
Source File: SpreadsheetRelation.scala    From mimir with Apache License 2.0
package mimir.exec.spark.datasource.google.spreadsheet

import mimir.exec.spark.datasource.google.spreadsheet.SparkSpreadsheetService.SparkSpreadsheetContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

case class SpreadsheetRelation protected[spark] (
                                                  context:SparkSpreadsheetContext,
                                                  spreadsheetName: String,
                                                  worksheetName: String,
                                                  userSchema: Option[StructType] = None)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with InsertableRelation {

  import mimir.exec.spark.datasource.google.spreadsheet.SparkSpreadsheetService._

  private val fieldMap = scala.collection.mutable.Map[String, String]()
  override def schema: StructType = userSchema.getOrElse(inferSchema())

  private lazy val aWorksheet: SparkWorksheet =
    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(aWorksheet) => aWorksheet
      case Left(e) => throw e
    }

  private lazy val rows: Seq[Map[String, String]] = aWorksheet.rows

  private[spreadsheet] def findWorksheet(spreadsheetName: String, worksheetName: String)(implicit ctx: SparkSpreadsheetContext): Either[Throwable, SparkWorksheet] =
    for {
      sheet <- findSpreadsheet(spreadsheetName).toRight(new RuntimeException(s"no such spreadsheet: $spreadsheetName")).right
      worksheet <- sheet.findWorksheet(worksheetName).toRight(new RuntimeException(s"no such worksheet: $worksheetName")).right
    } yield worksheet

  override def buildScan(): RDD[Row] = {
    val aSchema = schema
    val schemaMap = fieldMap.toMap
    sqlContext.sparkContext.makeRDD(rows).mapPartitions { iter =>
      iter.map { m =>
        var index = 0
        val rowArray = new Array[Any](aSchema.fields.length)
        while(index < aSchema.fields.length) {
          val field = aSchema.fields(index)
          rowArray(index) = if (m.contains(field.name)) {
            TypeCast.castTo(m(field.name), field.dataType, field.nullable)
          } else if (schemaMap.contains(field.name) && m.contains(schemaMap(field.name))) {
            TypeCast.castTo(m(schemaMap(field.name)), field.dataType, field.nullable)
          } else {
            null
          }
          index += 1
        }
        Row.fromSeq(rowArray)
      }
    }
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    if(!overwrite) {
      sys.error("Spreadsheet tables only support INSERT OVERWRITE for now.")
    }

    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(w) =>
        w.updateCells(data.schema, data.collect().toList, Util.toRowData)
      case Left(e) =>
        throw e
    }
  }

  def sanitizeColumnName(name: String): String =
  {
    name
      .replaceAll("[^a-zA-Z0-9]+", "_")    // Replace sequences of non-alphanumeric characters with underscores
      .replaceAll("_+$", "")               // Strip trailing underscores
      .replaceAll("^[0-9_]+", "")          // Strip leading underscores and digits
  }

  private def inferSchema(): StructType =
    StructType(aWorksheet.headers.toList.map { fieldName => {
      val sanitizedName = sanitizeColumnName(fieldName)
      fieldMap.put(sanitizedName, fieldName)
      StructField(sanitizedName, StringType, true)
    }})

} 
Example 8
Source File: PulsarRelation.scala    From pulsar-spark with Apache License 2.0
package org.apache.spark.sql.pulsar

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType


private[pulsar] class PulsarRelation(
    override val sqlContext: SQLContext,
    override val schema: StructType,
    schemaInfo: SchemaInfoSerializable,
    adminUrl: String,
    clientConf: ju.Map[String, Object],
    readerConf: ju.Map[String, Object],
    startingOffset: SpecificPulsarOffset,
    endingOffset: SpecificPulsarOffset,
    pollTimeoutMs: Int,
    failOnDataLoss: Boolean,
    subscriptionNamePrefix: String,
    jsonOptions: JSONOptionsInRead)
    extends BaseRelation
    with TableScan
    with Logging {

  import PulsarSourceUtils._

  val reportDataLoss = reportDataLossFunc(failOnDataLoss)

  override def buildScan(): RDD[Row] = {
    val fromTopicOffsets = startingOffset.topicOffsets
    val endTopicOffsets = endingOffset.topicOffsets

    if (fromTopicOffsets.keySet != endTopicOffsets.keySet) {
      val fromTopics = fromTopicOffsets.keySet.toList.sorted.mkString(",")
      val endTopics = endTopicOffsets.keySet.toList.sorted.mkString(",")
      throw new IllegalStateException(
        "different topics " +
          s"for starting offsets topics[${fromTopics}] and " +
          s"ending offsets topics[${endTopics}]")
    }

    val offsetRanges = endTopicOffsets.keySet
      .map { tp =>
        val fromOffset = fromTopicOffsets.getOrElse(tp, {
          // this shouldn't happen since we had checked it
          throw new IllegalStateException(s"$tp doesn't have a from offset")
        })
        val untilOffset = endTopicOffsets(tp)
        PulsarOffsetRange(tp, fromOffset, untilOffset, None)
      }
      .filter { range =>
        if (range.untilOffset.compareTo(range.fromOffset) < 0) {
          reportDataLoss(
            s"${range.topic}'s offset was changed " +
              s"from ${range.fromOffset} to ${range.untilOffset}, " +
              "some data may have been missed")
          false
        } else {
          true
        }
      }
      .toSeq

    val rdd = new PulsarSourceRDD4Batch(
      sqlContext.sparkContext,
      schemaInfo,
      adminUrl,
      clientConf,
      readerConf,
      offsetRanges,
      pollTimeoutMs,
      failOnDataLoss,
      subscriptionNamePrefix,
      jsonOptions
    )
    sqlContext.internalCreateDataFrame(rdd.setName("pulsar"), schema).rdd
  }
} 
Example 9
Source File: GDBRelation.scala    From spark-gdb with Apache License 2.0
package com.esri.gdb

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}


case class GDBRelation(gdbPath: String, gdbName: String, numPartition: Int)
                      (@transient val sqlContext: SQLContext)
  extends BaseRelation with Logging with TableScan {

  override val schema = inferSchema()

  private def inferSchema() = {
    val sc = sqlContext.sparkContext
    GDBTable.findTable(gdbPath, gdbName, sc.hadoopConfiguration) match {
      case Some(catTab) => {
        val table = GDBTable(gdbPath, catTab.hexName, sc.hadoopConfiguration)
        try {
          table.schema()
        } finally {
          table.close()
        }
      }
      case _ => {
        log.error(s"Cannot find '$gdbName' in $gdbPath, creating an empty schema!")
        StructType(Seq.empty[StructField])
      }
    }
  }

  override def buildScan(): RDD[Row] = {
    GDBRDD(sqlContext.sparkContext, gdbPath, gdbName, numPartition)
  }
} 
Example 10
Source File: SpreadsheetRelation.scala    From spark-google-spreadsheets with Apache License 2.0
package com.github.potix2.spark.google.spreadsheets

import com.github.potix2.spark.google.spreadsheets.SparkSpreadsheetService.SparkSpreadsheetContext
import com.github.potix2.spark.google.spreadsheets.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

case class SpreadsheetRelation protected[spark] (
                                                  context:SparkSpreadsheetContext,
                                                  spreadsheetName: String,
                                                  worksheetName: String,
                                                  userSchema: Option[StructType] = None)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with InsertableRelation {

  import com.github.potix2.spark.google.spreadsheets.SparkSpreadsheetService._

  override def schema: StructType = userSchema.getOrElse(inferSchema())

  private lazy val aWorksheet: SparkWorksheet =
    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(aWorksheet) => aWorksheet
      case Left(e) => throw e
    }

  private lazy val rows: Seq[Map[String, String]] = aWorksheet.rows

  private[spreadsheets] def findWorksheet(spreadsheetName: String, worksheetName: String)(implicit ctx: SparkSpreadsheetContext): Either[Throwable, SparkWorksheet] =
    for {
      sheet <- findSpreadsheet(spreadsheetName).toRight(new RuntimeException(s"no such spreadsheet: $spreadsheetName")).right
      worksheet <- sheet.findWorksheet(worksheetName).toRight(new RuntimeException(s"no such worksheet: $worksheetName")).right
    } yield worksheet

  override def buildScan(): RDD[Row] = {
    val aSchema = schema
    sqlContext.sparkContext.makeRDD(rows).mapPartitions { iter =>
      iter.map { m =>
        var index = 0
        val rowArray = new Array[Any](aSchema.fields.length)
        while(index < aSchema.fields.length) {
          val field = aSchema.fields(index)
          rowArray(index) = if (m.contains(field.name)) {
            TypeCast.castTo(m(field.name), field.dataType, field.nullable)
          } else {
            null
          }
          index += 1
        }
        Row.fromSeq(rowArray)
      }
    }
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    if(!overwrite) {
      sys.error("Spreadsheet tables only support INSERT OVERWRITE for now.")
    }

    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(w) =>
        w.updateCells(data.schema, data.collect().toList, Util.toRowData)
      case Left(e) =>
        throw e
    }
  }

  private def inferSchema(): StructType =
    StructType(aWorksheet.headers.toList.map { fieldName =>
      StructField(fieldName, StringType, nullable = true)
    })

} 
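
A hedged usage sketch: the load path addresses "<spreadsheetName>/<worksheetName>", as suggested by the case class parameters; the credential option names are assumptions and may differ between versions of the package:

val df = sqlContext.read
  .format("com.github.potix2.spark.google.spreadsheets")
  .option("serviceAccountId", "someaccount@developer.gserviceaccount.com")
  .option("credentialPath", "/path/to/credential.p12")
  .load("MySpreadsheet/Sheet1")
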
Example 11
Source File: HelloWorldDataSource.scala    From apache-spark-test with Apache License 2.0
package com.github.dnvriend.spark.datasources.helloworld

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{ BaseRelation, DataSourceRegister, RelationProvider, TableScan }
import org.apache.spark.sql.types.{ StringType, StructField, StructType }
import org.apache.spark.sql.{ Row, SQLContext }

class HelloWorldDataSource extends RelationProvider with DataSourceRegister with Serializable {
  override def shortName(): String = "helloworld"

  override def hashCode(): Int = getClass.hashCode()

  override def equals(other: scala.Any): Boolean = other.isInstanceOf[HelloWorldDataSource]

  override def toString: String = "HelloWorldDataSource"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(p) => new HelloWorldRelationProvider(sqlContext, p, parameters)
      case _       => throw new IllegalArgumentException("Path is required for HelloWorld datasets")
    }
  }
}

class HelloWorldRelationProvider(val sqlContext: SQLContext, path: String, parameters: Map[String, String]) extends BaseRelation with TableScan {
  import sqlContext.implicits._

  override def schema: StructType = StructType(Array(
    StructField("key", StringType, nullable = false),
    StructField("value", StringType, nullable = true)
  ))

  override def buildScan(): RDD[Row] =
    Seq(
      "path" -> path,
      "message" -> parameters.getOrElse("message", ""),
      "name" -> s"Hello ${parameters.getOrElse("name", "")}",
      "hello_world" -> "Hello World!"
    ).toDF.rdd
} 
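
Because HelloWorldDataSource registers the short name "helloworld" through DataSourceRegister, and createRelation reads path, message, and name from the parameters map (load() supplies "path"), the source can be exercised as below; the option values and path are placeholders:

val df = sqlContext.read
  .format("helloworld")
  .option("message", "testing")
  .option("name", "Spark")
  .load("/tmp/anything")
df.show()
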
Example 12
Source File: TensorflowRelation.scala    From ecosystem with Apache License 2.0
package org.tensorflow.spark.datasources.tfrecords

import org.apache.hadoop.io.{BytesWritable, NullWritable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext, SparkSession}
import org.tensorflow.example.{SequenceExample, Example}
import org.tensorflow.hadoop.io.TFRecordFileInputFormat
import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowDecoder


case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None)
                             (@transient val session: SparkSession) extends BaseRelation with TableScan {

  // Importing TFRecords as a DataFrame happens here
  lazy val (tfRdd, tfSchema) = {
    val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable])

    val recordType = options.getOrElse("recordType", "Example")

    recordType match {
      case "Example" =>
        val exampleRdd = rdd.map{case (bytesWritable, nullWritable) =>
          Example.parseFrom(bytesWritable.getBytes)
        }
        val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(exampleRdd))
        val rowRdd = exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeExample(example, finalSchema))
        (rowRdd, finalSchema)
      case "SequenceExample" =>
        val sequenceExampleRdd = rdd.map{case (bytesWritable, nullWritable) =>
          SequenceExample.parseFrom(bytesWritable.getBytes)
        }
        val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(sequenceExampleRdd))
        val rowRdd = sequenceExampleRdd.map(example => DefaultTfRecordRowDecoder.decodeSequenceExample(example, finalSchema))
        (rowRdd, finalSchema)
      case _ =>
        throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample")
    }
  }

  override def sqlContext: SQLContext = session.sqlContext

  override def schema: StructType = tfSchema

  override def buildScan(): RDD[Row] = tfRdd
}
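
TensorflowRelation consumes options("path") and an optional recordType of Example or SequenceExample; a hedged sketch addressing the source by its package name (the path is a placeholder):

val df = spark.read
  .format("org.tensorflow.spark.datasources.tfrecords")
  .option("recordType", "Example")
  .load("/path/to/part-00000.tfrecord")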