org.apache.commons.logging.LogFactory Scala Examples
The following examples show how to use org.apache.commons.logging.LogFactory.
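All of the examples share the same basic pattern: obtain a Log from LogFactory, keyed either by a class or by an arbitrary name, and guard verbose messages with the is*Enabled checks. A minimal sketch of that pattern (the class name is illustrative, not taken from any example):

import org.apache.commons.logging.{Log, LogFactory}

// Illustrative class; mirrors the pattern used throughout the examples below.
class MyService {
  // A logger can be keyed by class ...
  private val log: Log = LogFactory.getLog(classOf[MyService])
  // ... or by name, e.g. LogFactory.getLog("db-writer-app")

  def run(): Unit = {
    log.info("Service starting")
    if (log.isDebugEnabled) {
      // guard string construction for verbose levels
      log.debug(s"Working directory: ${System.getProperty("user.dir")}")
    }
  }
}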
Example 1
Source File: Init.scala From cave with MIT License

package init

import java.net.InetAddress
import java.util.UUID

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker}
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory
import com.cave.metrics.data.AwsConfig
import com.cave.metrics.data.influxdb.{InfluxConfiguration, InfluxDataSink}
import com.cave.metrics.data.kinesis.RecordProcessorFactory
import com.typesafe.config.ConfigFactory
import org.apache.commons.logging.LogFactory
import play.api.Play

import scala.util.Try

object Init {

  // Docker should place the stream name in this environment variable
  final val EnvStreamName = "STREAM_NAME"

  // The name of this application for Kinesis Client Library
  final val ApplicationName = "cave-db-worker"

  // CloudWatch Reporter parameters
  final val MetricsNamespace = s"metrics-$ApplicationName"
  final val MetricsBufferTime = 1000L
  final val MetricsBufferSize = 200

  final val ThreadWaitTimeout = 10000L

  private val Log = LogFactory.getLog("db-writer-app")

  val worker = createWorker()
  val workerThread = new Thread(worker)

  def start(): Unit = {
    workerThread.start()
  }

  def shutdown(): Unit = {
    worker.shutdown()
    Try(workerThread.join(ThreadWaitTimeout)) recover {
      case e: Exception =>
        Log.info(s"Caught exception while joining worker thread: $e")
    }
  }

  private[this] def createWorker(): Worker = {
    val configuration = Play.current.configuration
    val serviceConfFile = configuration.getString("serviceConf").getOrElse("db-writer-service.conf")
    val kinesisAppName = configuration.getString("appName").getOrElse(ApplicationName)
    val appConfig = ConfigFactory.load(serviceConfFile).getConfig("db-writer")
    val awsConfig = new AwsConfig(appConfig)

    val streamName = System.getenv(EnvStreamName) match {
      case "processed" => awsConfig.processedStreamName
      case _ => awsConfig.rawStreamName
    }

    val workerId = s"${InetAddress.getLocalHost.getCanonicalHostName}:${UUID.randomUUID()}"
    Log.info(s"Running $ApplicationName for stream $streamName as worker $workerId")

    // a connection to the InfluxDB backend
    val influxConfig = appConfig.getConfig("influx")

    new Worker(
      // a factory for record processors
      new RecordProcessorFactory(
        awsConfig,
        new InfluxDataSink(InfluxConfiguration(influxConfig))),

      // a client library instance
      new KinesisClientLibConfiguration(kinesisAppName, streamName, awsConfig.awsCredentialsProvider, workerId)
        .withInitialPositionInStream(InitialPositionInStream.TRIM_HORIZON),

      new NullMetricsFactory)
    // TODO: check out the possibility to use CloudWatch Metrics
    // new CWMetricsFactory(awsConfig.awsCredentialsProvider, MetricsNamespace, MetricsBufferTime, MetricsBufferSize))
  }
}
Example 2
Source File: ExponentialBackOff.scala From cave with MIT License

package com.cave.metrics.data

import org.apache.commons.logging.LogFactory

// NOTE: members referenced below but not declared in this excerpt
// (MaxBackOffTimeInMillis, ShouldLogErrors, currentSleepTimeInMillis, log)
// are defined elsewhere in the original trait.
trait ExponentialBackOff {

  protected[this] final def loopWithBackOffOnErrorWhile(condition: => Boolean)(body: => Unit) {
    while (condition) {
      try {
        body
        backOffReset()
      } catch {
        case e: Throwable =>
          if (ShouldLogErrors) {
            log.error(e.getMessage)
          }
          backOffOnError()
      }
    }
  }

  protected[this] final def retry[T](operation: => T): T = {
    retryUpTo(Long.MaxValue)(operation)
  }

  protected[this] final def retryUpTo[T](maxRetry: Long)(operation: => T): T = {
    var numRetries = 0L
    var result = Option.empty[T]

    loopWithBackOffOnErrorWhile(!result.isDefined && numRetries < maxRetry) {
      numRetries += 1
      result = Some(operation)
    }

    result getOrElse {
      sys.error(s"Max number of retries reached [$maxRetry], operation aborted.")
    }
  }

  protected[this] final def backOffReset() {
    currentSleepTimeInMillis = 1L
  }

  protected[this] final def backOffOnError() {
    try {
      Thread.sleep(currentSleepTimeInMillis)
    } catch {
      case _: InterruptedException => // ignore interrupted exception
    }
    currentSleepTimeInMillis *= 2
    if (currentSleepTimeInMillis > MaxBackOffTimeInMillis) {
      currentSleepTimeInMillis = MaxBackOffTimeInMillis
    }
  }
}
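For context, Example 4 below (RecordProcessor) mixes this trait in, overrides the back-off settings, and wraps a flaky call in retryUpTo. A minimal sketch of such a consumer; the class is hypothetical and assumes the trait's elided members are provided by the real trait:

import com.cave.metrics.data.ExponentialBackOff

// Hypothetical consumer, modelled on RecordProcessor in Example 4.
class FlakyKinesisCaller extends ExponentialBackOff {
  // back-off settings declared by the trait
  override val MaxBackOffTimeInMillis = 10000L
  override val ShouldLogErrors: Boolean = true

  // Retry the operation up to 10 times, doubling the sleep after each failure.
  def checkpointWithRetry(doCheckpoint: () => Unit): Unit =
    retryUpTo(10) {
      doCheckpoint()
    }
}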
Example 3
Source File: DatabaseConnection.scala From cave with MIT License

package com.cave.metrics.data.postgresql

import com.cave.metrics.data.AwsConfig
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
import org.apache.commons.logging.LogFactory

import scala.slick.driver.PostgresDriver.simple._

abstract class DatabaseConnection(awsConfig: AwsConfig) {

  val log = LogFactory.getLog(classOf[DatabaseConnection])

  val ds = new HikariDataSource(getDatabaseConfig)

  val db = {
    val database = Database.forDataSource(ds)
    log.debug(
      s"""
        Db connection initialized.
        driver: ${awsConfig.rdsJdbcDatabaseClass}
        user: ${awsConfig.rdsJdbcDatabaseUser}
        pass: [REDACTED]
      """.stripMargin)
    ds.getConnection.close()
    database
  }

  def closeDbConnection(): Unit = ds.close()

  private[this] def getDatabaseConfig: HikariConfig = {
    val config = new HikariConfig
    config.setMaximumPoolSize(awsConfig.rdsJdbcDatabasePoolSize)

    val className = awsConfig.rdsJdbcDatabaseClass
    config.setDataSourceClassName(awsConfig.rdsJdbcDatabaseClass)

    if (className.contains("postgres")) {
      config.addDataSourceProperty("serverName", awsConfig.rdsJdbcDatabaseServer)
      config.addDataSourceProperty("databaseName", awsConfig.rdsJdbcDatabaseName)
      config.addDataSourceProperty("portNumber", awsConfig.rdsJdbcDatabasePort)
    } else {
      config.addDataSourceProperty("url", awsConfig.rdsJdbcDatabaseUrl)
    }

    config.addDataSourceProperty("user", awsConfig.rdsJdbcDatabaseUser)
    config.addDataSourceProperty("password", awsConfig.rdsJdbcDatabasePassword)
    config
  }
}
Example 4
Source File: RecordProcessor.scala From cave with MIT License

package com.cave.metrics.data.kinesis

import java.util.{List => JList}

import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer}
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.model.Record
import com.cave.metrics.data._
import org.apache.commons.logging.LogFactory
import play.api.libs.json.Json

import scala.collection.JavaConverters._
import scala.util.{Success, Try}

class RecordProcessor(config: AwsConfig, sink: DataSink) extends IRecordProcessor with ExponentialBackOff {

  private[this] var shardId: String = _
  private var nextCheckpointTimeMillis: Long = _
  private[this] val log = LogFactory.getLog(classOf[RecordProcessor])

  // Back off and retry settings for checkpoint
  override val MaxBackOffTimeInMillis = 10000L
  override val ShouldLogErrors: Boolean = true
  private val NumRetries = 10
  private val CheckpointIntervalInMillis = 1000L

  override def initialize(shardId: String): Unit = {
    this.shardId = shardId
  }

  override def shutdown(check: IRecordProcessorCheckpointer, reason: ShutdownReason): Unit = {
    if (reason == ShutdownReason.TERMINATE) {
      checkpoint(check)
    }
  }

  override def processRecords(records: JList[Record], check: IRecordProcessorCheckpointer): Unit = {
    val metrics = (records.asScala map convert).filter(_.isSuccess)
    if (metrics.size == records.size()) {
      // all metrics successfully converted
      log.info(s"Received $metrics")
      sink.sendMetrics(for (Success(metric) <- metrics) yield metric)
    } else {
      log.error("Failed to parse records into Metric objects.")
    }

    if (System.currentTimeMillis() > nextCheckpointTimeMillis) {
      checkpoint(check)
      nextCheckpointTimeMillis = System.currentTimeMillis() + CheckpointIntervalInMillis
    }
  }

  private[this] def convert(record: Record): Try[Metric] =
    Try(Json.parse(new String(record.getData.array())).as[Metric])

  private[this] def checkpoint(check: IRecordProcessorCheckpointer): Unit = {
    Try {
      retryUpTo(NumRetries) {
        check.checkpoint()
      }
    } recover {
      case e: Exception =>
        log.warn(s"Failed to checkpoint shard $shardId: ${e.getMessage}")
    }
  }
}
Example 5
Source File: KinesisDataSink.scala From cave with MIT License

package com.cave.metrics.data.kinesis

import java.nio.ByteBuffer

import com.amazonaws.services.kinesis.AmazonKinesisAsyncClient
import com.amazonaws.services.kinesis.model.PutRecordRequest
import com.cave.metrics.data.{AwsConfig, Metric, SeqDataSink}
import org.apache.commons.logging.LogFactory
import play.api.libs.json.Json

import scala.util.Try
import scala.util.control.Exception._
import scala.util.control.NonFatal

class KinesisDataSink(config: AwsConfig, streamName: String) extends SeqDataSink {

  val log = LogFactory.getLog(classOf[KinesisDataSink])

  var client: Option[AmazonKinesisAsyncClient] = None

  override def sendMetric(metric: Metric): Unit = {

    def createRequest: PutRecordRequest = {
      val data = Json.toJson(metric).toString()
      log.info(s"Sending $data ...")

      val request = new PutRecordRequest
      request.setStreamName(streamName)
      request.setData(ByteBuffer.wrap(data.getBytes))
      request.setPartitionKey(metric.partitionKey)

      request
    }

    client foreach { c =>
      Try(c.putRecord(createRequest)) recover {
        case NonFatal(e) =>
          log.warn(s"Caught exception while talking to Kinesis: $e")
          throw e
      }
    }
  }
}
Example 6
Source File: Global.scala From cave with MIT License

import scala.concurrent.Future

import init.Init
import org.apache.commons.logging.LogFactory
import play.api._
import play.api.mvc._
import play.api.mvc.Results._

object Global extends GlobalSettings {

  override def onHandlerNotFound(request: RequestHeader) = {
    Future.successful(NotFound)
  }

  override def onBadRequest(request: RequestHeader, error: String) = {
    Future.successful(BadRequest("Bad Request: " + error))
  }

  override def onError(request: RequestHeader, ex: Throwable) = {
    Logger.error(ex.toString, ex)
    Future.successful(InternalServerError(ex.toString))
  }

  override def onStart(app: Application) {
    Init.init()
  }

  override def onStop(app: Application) {
    Init.shutdown()
  }
}
Example 7
Source File: AwsWrapper.scala From cave with MIT License

package init

import com.amazonaws.services.sqs.AmazonSQSAsyncClient
import com.amazonaws.services.sqs.model.{DeleteMessageRequest, ReceiveMessageRequest}
import com.cave.metrics.data.kinesis.KinesisDataSink
import com.cave.metrics.data.{Check, AwsConfig}
import org.apache.commons.logging.LogFactory
import play.api.libs.json.Json

import scala.concurrent._
import scala.collection.JavaConverters._

object AwsWrapper {
  case class WorkItem(itemId: String, receiptHandle: String, check: Check)
}

class AwsWrapper(awsConfig: AwsConfig) {

  private final val Log = LogFactory.getLog(this.getClass)
  private final val MaxNumberOfMessages = 10

  val dataSink = new KinesisDataSink(awsConfig, awsConfig.rawStreamName)

  // the SQS client
  val sqsClient = {
    val c = new AmazonSQSAsyncClient(awsConfig.awsCredentialsProvider)
    c.setEndpoint(awsConfig.awsSQSConfig.endpoint)
    c
  }

  val queueName = awsConfig.alarmScheduleQueueName
  val queueUrl = sqsClient.createQueue(queueName).getQueueUrl
  Log.info(s"Queue $queueName has URL $queueUrl")

  import AwsWrapper._

  def init(): Unit = {
    dataSink.connect()
  }

  def shutdown(): Unit = {
    dataSink.disconnect()
  }

  def deleteMessage(receiptHandle: String)(implicit ec: ExecutionContext): Future[Boolean] = {
    val request = new DeleteMessageRequest()
      .withQueueUrl(queueUrl)
      .withReceiptHandle(receiptHandle)

    future {
      blocking {
        val response = sqsClient.deleteMessageAsync(request)
        response.get()
        response.isDone
      }
    }
  }
}
Example 8
Source File: Init.scala From cave with MIT License

package init

import akka.actor._
import com.cave.metrics.data.AwsConfig
import com.cave.metrics.data.influxdb.{InfluxClientFactory, InfluxConfiguration}
import com.typesafe.config.ConfigFactory
import org.apache.commons.logging.LogFactory
import play.api.Play
import worker.Coordinator
import worker.converter.ConverterFactory
import worker.web.AsyncNotificationSender

object Init {

  private[this] val configuration = Play.current.configuration
  private val log = LogFactory.getLog("Init")

  val serviceConfFile = configuration.getString("serviceConf").getOrElse("worker.conf")
  val appConfig = ConfigFactory.load(serviceConfFile).getConfig("worker")

  // prepare AWS config
  val awsConfig = new AwsConfig(appConfig)

  // a wrapper for required AWS
  val awsWrapper = new AwsWrapper(awsConfig)

  // a connection to the InfluxDB backend
  val influxConfig = appConfig.getConfig("influx")
  val influxClientFactory = new InfluxClientFactory(InfluxConfiguration(influxConfig))

  val converterFactory = new ConverterFactory(appConfig.getConfig("converters"))
  val sender = new AsyncNotificationSender(converterFactory)

  val system = ActorSystem("CaveWorker")
  val coordinator = system.actorOf(Props(new Coordinator(awsWrapper)), "coordinator")

  def init() {
    log.info("Init started...")
    awsWrapper.init()
    log.info("Init completed.")
  }

  def shutdown() {
    log.info("Shutdown started...")
    awsWrapper.shutdown()
    influxClientFactory.close()
    system.shutdown()
    log.info("Shutdown completed.")
  }
}
Example 9
Source File: Global.scala From cave with MIT License

import scala.concurrent.Future

import init.Init
import org.apache.commons.logging.LogFactory
import play.api._
import play.api.mvc._
import play.api.mvc.Results._

object Global extends GlobalSettings {

  private[this] final val Log = LogFactory.getLog(this.getClass)

  override def onHandlerNotFound(request: RequestHeader) = {
    Future.successful(NotFound)
  }

  override def onBadRequest(request: RequestHeader, error: String) = {
    Future.successful(BadRequest("Bad Request: " + error))
  }

  override def onError(request: RequestHeader, ex: Throwable) = {
    Logger.error(ex.toString, ex)
    Future.successful(InternalServerError(ex.toString))
  }

  override def onStart(app: Application) {
    Init.init()
  }

  override def onStop(app: Application) {
    Init.shutdown()
  }
}
Example 10
Source File: AkkaConfig.scala From cave with MIT License

package init

import com.typesafe.config.{ConfigValueFactory, ConfigFactory}
import org.apache.commons.logging.LogFactory

import scala.collection.JavaConversions._

class AkkaConfig(hostname: Option[String], awsWrapper: AwsWrapper) {

  private final val Log = LogFactory.getLog(this.getClass)

  private final val AkkaPort = 2551
  private final val Localhost = "localhost"

  private val (host: String, siblings: List[String]) = hostname match {
    case Some(name) if name.length > 0 => (name, awsWrapper.getNodes("scheduler"))
    case _ => (Localhost, List(Localhost))
  }

  private val seeds = siblings map (ip => s"akka.tcp://scheduler@$ip:$AkkaPort")
  Log.warn(s"Seeds: ${seeds.mkString}")

  private val overrideConfig = ConfigFactory.empty()
    .withValue("akka.remote.netty.tcp.hostname", ConfigValueFactory.fromAnyRef(host))
    .withValue("akka.remote.netty.tcp.port", ConfigValueFactory.fromAnyRef(AkkaPort))
    .withValue("akka.cluster.seed-nodes", ConfigValueFactory.fromIterable(seeds))

  private val defaults = ConfigFactory.load()

  val config = overrideConfig withFallback defaults
}
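A hedged wiring sketch for this class: the actor-system name "scheduler" matches the seed-node URIs built above, but the bootstrap object itself is hypothetical.

package init

import akka.actor.ActorSystem

// Hypothetical bootstrap: merge the override config with the defaults and start
// an ActorSystem named "scheduler", matching akka.tcp://scheduler@host:2551 above.
object SchedulerBoot {
  def boot(hostname: Option[String], awsWrapper: AwsWrapper): ActorSystem =
    ActorSystem("scheduler", new AkkaConfig(hostname, awsWrapper).config)
}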
Example 11
Source File: Global.scala From cave with MIT License

import filters.HttpsFilter
import init.Init
import org.apache.commons.logging.LogFactory
import play.api._
import play.api.mvc.Results._
import play.api.mvc._

import scala.concurrent.Future

object Global extends WithFilters(HttpsFilter) with GlobalSettings {

  private[this] final val Log = LogFactory.getLog(this.getClass)

  override def onHandlerNotFound(request: RequestHeader) = {
    Future.successful(NotFound)
  }

  override def onBadRequest(request: RequestHeader, error: String) = {
    Future.successful(BadRequest("Bad Request: " + error))
  }

  override def onError(request: RequestHeader, ex: Throwable) = {
    Logger.error(ex.toString, ex)
    Future.successful(InternalServerError(ex.toString))
  }

  override def onStart(app: Application) {
    Init.init()
  }

  override def onStop(app: Application) {
    Init.shutdown()
  }
}
Example 12
Source File: AwsWrapper.scala From cave with MIT License

package init

import com.amazonaws.services.sns.AmazonSNSAsyncClient
import com.amazonaws.services.sns.model.{PublishRequest, PublishResult}
import com.cave.metrics.data.Operation.Operation
import com.cave.metrics.data._
import com.cave.metrics.data.kinesis.KinesisDataSink
import com.cave.metrics.data.postgresql.PostgresDataManagerImpl
import org.apache.commons.logging.LogFactory
import play.api.libs.concurrent.Execution.Implicits._
import play.api.libs.json.Json

import scala.concurrent._

class AwsWrapper(awsConfig: AwsConfig) {

  private final val log = LogFactory.getLog(this.getClass)

  val dataSink = new KinesisDataSink(awsConfig, awsConfig.rawStreamName)

  // a connection to the Postgres backend
  val dataManager: DataManager = new PostgresDataManagerImpl(awsConfig)

  // the SNS client
  val snsClient = {
    val c = new AmazonSNSAsyncClient(awsConfig.awsCredentialsProvider)
    c.setEndpoint(awsConfig.awsSNSConfig.endpoint)
    c
  }

  val topicArn = getTopicArn

  def getTopicArn = {
    val topicName = awsConfig.configurationChangesTopicName
    val arn = snsClient.createTopic(topicName).getTopicArn
    log.info(s"Topic $topicName has ARN $arn")
    arn
  }

  def init() = {
    dataSink.connect()
  }

  def shutdown() = {
    dataSink.disconnect()
  }

  def createOrganizationNotification(org: Organization) =
    sendOrganization(Operation.Create, org.name, "")

  def updateOrganizationNotification(org: Organization) =
    sendOrganization(Operation.Update, org.name, org.notificationUrl)

  def deleteOrganizationNotification(orgName: String) =
    sendOrganization(Operation.Delete, orgName, "")

  def createAlertNotification(schedule: Schedule) = sendAlert(Operation.Create, schedule)

  def updateAlertNotification(schedule: Schedule) = sendAlert(Operation.Update, schedule)

  def deleteAlertNotification(scheduleId: String, orgName: String) =
    sendNotification(Update(Entity.Alert, Operation.Delete, scheduleId, orgName))

  private[init] def sendOrganization(op: Operation, orgName: String, extra: String) =
    sendNotification(Update(Entity.Organization, op, orgName, extra))

  private[init] def sendAlert(op: Operation, schedule: Schedule) = {
    sendNotification(Update(Entity.Alert, op, schedule.alert.id.get, Json.stringify(Json.toJson(schedule))))
  }

  private[init] def sendNotification(update: Update): Future[Unit] = {
    val message = Json.stringify(Json.toJson(update))
    val request = new PublishRequest(topicArn, message)
    val response = snsClient publishAsync request

    future {
      blocking {
        val result: PublishResult = response.get()
        log.info(s"Successfully posted the notification '$message', messageId: ${result.getMessageId}")
      }
    }
  }
}
Example 13
Source File: Init.scala From cave with MIT License

package init

import java.util.concurrent.TimeUnit

import com.cave.metrics.data.influxdb.{InfluxClientFactory, InfluxConfiguration}
import com.cave.metrics.data.metrics.InternalReporter
import com.cave.metrics.data.{AlertManager, AwsConfig, Metric, PasswordHelper}
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet, ThreadStatesGaugeSet}
import com.typesafe.config.ConfigFactory
import org.apache.commons.logging.LogFactory
import play.api.Play

object Init {

  val metricRegistry = new MetricRegistry

  private val log = LogFactory.getLog("Init")
  private val InternalTags = Map(Metric.Organization -> Metric.Internal)

  private[this] val configuration = Play.current.configuration

  val baseUrl = configuration.getString("baseUrl").getOrElse("https://api.cavellc.io")
  val maxTokens = configuration.getInt("maxTokens").getOrElse(3)

  val serviceConfFile = configuration.getString("serviceConf").getOrElse("api-service.conf")
  val appConfig = ConfigFactory.load(serviceConfFile).getConfig("api-service")

  // prepare AWS config and Kinesis data sink
  val awsConfig = new AwsConfig(appConfig)

  // a wrapper for required AWS
  val awsWrapper = new AwsWrapper(awsConfig)

  // a connection to the InfluxDB backend
  val influxConfig = appConfig.getConfig("influx")
  val influxClientFactory = new InfluxClientFactory(InfluxConfiguration(influxConfig))

  val alertManager = new AlertManager(awsWrapper.dataManager, influxClientFactory)

  val mailService = new MailService
  val passwordHelper = new PasswordHelper

  def init() {
    awsWrapper.init()
    log.warn("Init.init()")

    val reporter = InternalReporter(registry = metricRegistry) { metrics =>
      metrics foreach (metric =>
        awsWrapper.dataSink.sendMetric(Metric(metric.name, metric.timestamp, metric.value, InternalTags ++ metric.tags)))
    }

    reporter.start(1, TimeUnit.MINUTES)

    metricRegistry.register(MetricRegistry.name("jvm", "gc"), new GarbageCollectorMetricSet())
    metricRegistry.register(MetricRegistry.name("jvm", "memory"), new MemoryUsageGaugeSet())
    metricRegistry.register(MetricRegistry.name("jvm", "thread-states"), new ThreadStatesGaugeSet())
  }

  def shutdown() {
    awsWrapper.shutdown()
    influxClientFactory.close()
    log.warn("Init.shutdown()")
  }
}
Example 14
Source File: Global.scala From cave with MIT License

import filters.HttpsAndWwwRedirectForElbFilter
import init.Init
import org.apache.commons.logging.LogFactory
import play.api._
import play.api.mvc.Results._
import play.api.mvc._

import scala.concurrent.Future

object Global extends WithFilters(HttpsAndWwwRedirectForElbFilter) with GlobalSettings {

  private[this] final val Log = LogFactory.getLog(this.getClass)

  override def onHandlerNotFound(request: RequestHeader) = {
    Future.successful(NotFound(views.html.errorpages.pageNotFound(request.path)))
  }

  override def onBadRequest(request: RequestHeader, error: String) = {
    Future.successful(BadRequest("Bad Request: " + error))
  }

  override def onError(request: RequestHeader, ex: Throwable) = {
    Logger.error(ex.toString, ex)
    Future.successful(InternalServerError(views.html.errorpages.errorPage(ex.getMessage)))
  }

  override def onStart(app: Application) {
    Init.init
  }

  override def onStop(app: Application) {
    Init.shutdown()
  }
}
Example 15
Source File: Global.scala From cave with MIT License

import scala.concurrent.Future
import scala.util.{Failure, Success}

import init.Init
import org.apache.commons.logging.LogFactory
import play.api._
import play.api.mvc._
import play.api.mvc.Results._

object Global extends GlobalSettings {

  private[this] final val Log = LogFactory.getLog(this.getClass)

  override def onHandlerNotFound(request: RequestHeader) = {
    Future.successful(NotFound)
  }

  override def onBadRequest(request: RequestHeader, error: String) = {
    Future.successful(BadRequest("Bad Request: " + error))
  }

  override def onError(request: RequestHeader, ex: Throwable) = {
    Logger.error(ex.toString, ex)
    Future.successful(InternalServerError(ex.toString))
  }

  override def onStart(app: Application) {
    Init.start()
  }

  override def onStop(app: Application) {
    Init.shutdown()
  }
}
Example 16
Source File: SpotlightLog.scala From dbpedia-spotlight-model with Apache License 2.0

package org.dbpedia.spotlight.log

import org.apache.commons.logging.{Log, LogFactory}

import scala.collection.mutable

trait SpotlightLog[T] {
  def _debug(c: Class[_], msg: T, args: Any*)
  def _info(c: Class[_], msg: T, args: Any*)
  def _error(c: Class[_], msg: T, args: Any*)
  def _fatal(c: Class[_], msg: T, args: Any*)
  def _trace(c: Class[_], msg: T, args: Any*)
  def _warn(c: Class[_], msg: T, args: Any*)
}

object SpotlightLog {
  def debug[T](c: Class[_], msg: T, args: Any*)(implicit instance: SpotlightLog[T]) =
    instance._debug(c, msg, args: _*)
  def info[T](c: Class[_], msg: T, args: Any*)(implicit instance: SpotlightLog[T]) =
    instance._info(c, msg, args: _*)
  def error[T](c: Class[_], msg: T, args: Any*)(implicit instance: SpotlightLog[T]) =
    instance._error(c, msg, args: _*)
  def fatal[T](c: Class[_], msg: T, args: Any*)(implicit instance: SpotlightLog[T]) =
    instance._fatal(c, msg, args: _*)
  def trace[T](c: Class[_], msg: T, args: Any*)(implicit instance: SpotlightLog[T]) =
    instance._trace(c, msg, args: _*)
  def warn[T](c: Class[_], msg: T, args: Any*)(implicit instance: SpotlightLog[T]) =
    instance._warn(c, msg, args: _*)

  implicit object StringSpotlightLog extends SpotlightLog[String] {
    val loggers = new mutable.HashMap[Class[_], Log]()

    def _debug(c: Class[_], msg: String, args: Any*) = {
      val log = loggers.getOrElseUpdate(c, LogFactory.getLog(c))
      if (log.isDebugEnabled) {
        if (args.size == 0)
          log.debug(msg)
        else
          log.debug(msg.format(args: _*))
      }
    }

    def _info(c: Class[_], msg: String, args: Any*) = {
      val log = loggers.getOrElseUpdate(c, LogFactory.getLog(c))
      if (log.isInfoEnabled) {
        if (args.size == 0)
          log.info(msg)
        else
          log.info(msg.format(args: _*))
      }
    }

    def _error(c: Class[_], msg: String, args: Any*) = {
      val log = loggers.getOrElseUpdate(c, LogFactory.getLog(c))
      if (log.isErrorEnabled) {
        if (args.size == 0)
          log.error(msg)
        else
          log.error(msg.format(args: _*))
      }
    }

    def _fatal(c: Class[_], msg: String, args: Any*) = {
      val log = loggers.getOrElseUpdate(c, LogFactory.getLog(c))
      if (log.isFatalEnabled) {
        if (args.size == 0)
          log.fatal(msg)
        else
          log.fatal(msg.format(args: _*))
      }
    }

    def _trace(c: Class[_], msg: String, args: Any*) = {
      val log = loggers.getOrElseUpdate(c, LogFactory.getLog(c))
      if (log.isTraceEnabled) {
        if (args.size == 0)
          log.trace(msg)
        else
          log.trace(msg.format(args: _*))
      }
    }

    def _warn(c: Class[_], msg: String, args: Any*) = {
      val log = loggers.getOrElseUpdate(c, LogFactory.getLog(c))
      if (log.isWarnEnabled) {
        if (args.size == 0)
          log.warn(msg)
        else
          log.warn(msg.format(args: _*))
      }
    }
  }
}
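A brief usage sketch: callers go through the SpotlightLog companion object, and the implicit StringSpotlightLog instance above resolves the type class. The calling class here is hypothetical.

import org.dbpedia.spotlight.log.SpotlightLog

// Hypothetical caller; the implicit StringSpotlightLog supplies the instance.
class CandidateSearcher {
  SpotlightLog.info(this.getClass, "Loaded %d candidates in %d ms", 42, 17)
  SpotlightLog.debug(this.getClass, "plain message, no format arguments")
}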
Example 17
Source File: LocalMemoryDataBlock.scala From sona with Apache License 2.0

package com.tencent.angel.sona.data

import java.io.IOException
import java.util
import java.util.Collections

import com.tencent.angel.ml.math2.utils.{DataBlock, LabeledData}
import org.apache.commons.logging.{Log, LogFactory}
import org.ehcache.sizeof.SizeOf

class LocalMemoryDataBlock(initSize: Int, maxUseMemroy: Long) extends DataBlock[LabeledData] {

  private val LOG: Log = LogFactory.getLog(classOf[LocalMemoryDataBlock])

  private var estimateSampleNumber: Int = 100

  val initCapacity = if (initSize > 0) {
    estimateSampleNumber = initSize
    initSize
  } else {
    estimateSampleNumber
  }

  private val vList = new util.ArrayList[LabeledData]()
  private var isFull: Boolean = false

  @throws[IOException]
  override def read(): LabeledData = {
    if (readIndex < writeIndex) {
      val value = vList.get(readIndex)
      readIndex += 1
      value
    } else {
      null.asInstanceOf[LabeledData]
    }
  }

  @throws[IOException]
  override protected def hasNext: Boolean = readIndex < writeIndex

  @throws[IOException]
  override def get(index: Int): LabeledData = {
    if (index < 0 || index >= writeIndex) {
      throw new IOException("index not in range[0," + writeIndex + ")")
    }
    vList.get(index)
  }

  @throws[IOException]
  override def put(value: LabeledData): Unit = {
    if (writeIndex < estimateSampleNumber) {
      vList.add(value)
      writeIndex += 1
      if (writeIndex == estimateSampleNumber && !isFull) {
        estimateAndResizeVList()
      }
    } else {
      LOG.info("Over maxUseMemroy, No value added!")
    }
  }

  override def resetReadIndex(): Unit = {
    readIndex = 0
  }

  override def clean(): Unit = {
    readIndex = 0
    writeIndex = 0
    vList.clear()
  }

  override def shuffle(): Unit = Collections.shuffle(vList)

  override def flush(): Unit = {}

  override def slice(startIndex: Int, length: Int): DataBlock[LabeledData] = ???

  private def estimateAndResizeVList(): Unit = {
    val avgDataItemSize = (SizeOf.newInstance().deepSizeOf(vList) + vList.size - 1) / vList.size
    val maxStoreNum = (maxUseMemroy / avgDataItemSize).toInt

    val capacity = if (maxStoreNum < 2 * vList.size) {
      isFull = true
      maxStoreNum
    } else {
      2 * vList.size
    }

    estimateSampleNumber = (0.8 * capacity).toInt
    vList.ensureCapacity(capacity)

    LOG.debug("estimate sample number=" + vList.size
      + ", avgDataItemSize=" + avgDataItemSize
      + ", maxStoreNum=" + maxStoreNum
      + ", maxUseMemroy=" + maxUseMemroy)
  }
}
Example 18
Source File: IndexedBinaryBlockReader.scala From hail with MIT License

package is.hail.io

import is.hail.annotations.RegionValueBuilder
import is.hail.io.fs.{HadoopFS, WrappedSeekableDataInputStream}
import org.apache.commons.logging.{Log, LogFactory}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.mapred._

abstract class KeySerializedValueRecord[K] extends Serializable {
  var input: Array[Byte] = _
  var key: K = _

  def setSerializedValue(arr: Array[Byte]) {
    this.input = arr
  }

  def getValue(rvb: RegionValueBuilder, includeGT: Boolean): Unit

  def setKey(k: K) {
    this.key = k
  }

  def getKey: K = key
}

abstract class IndexedBinaryBlockReader[T](job: Configuration, split: FileSplit)
  extends RecordReader[LongWritable, T] {

  val LOG: Log = LogFactory.getLog(classOf[IndexedBinaryBlockReader[T]].getName)
  val partitionStart: Long = split.getStart
  var pos: Long = partitionStart
  val end: Long = partitionStart + split.getLength
  val bfis = openFile()

  def openFile(): HadoopFSDataBinaryReader = {
    val file: Path = split.getPath
    val fs: FileSystem = file.getFileSystem(job)
    val is = fs.open(file)
    new HadoopFSDataBinaryReader(
      new WrappedSeekableDataInputStream(
        HadoopFS.toSeekableInputStream(is)))
  }

  def createKey(): LongWritable = new LongWritable()

  def createValue(): T

  def getPos: Long = pos

  def getProgress: Float = {
    if (partitionStart == end)
      0.0f
    else
      Math.min(1.0f, (pos - partitionStart) / (end - partitionStart).toFloat)
  }

  def close() = bfis.close()
}
Example 19
Source File: RangerAdminClientImpl.scala From spark-ranger with Apache License 2.0

package org.apache.ranger.services.spark

import java.nio.file.{Files, FileSystems}
import java.util

import com.google.gson.GsonBuilder
import org.apache.commons.logging.{Log, LogFactory}
import org.apache.ranger.admin.client.RangerAdminRESTClient
import org.apache.ranger.plugin.util.{GrantRevokeRequest, ServicePolicies, ServiceTags}

class RangerAdminClientImpl extends RangerAdminRESTClient {
  private val LOG: Log = LogFactory.getLog(classOf[RangerAdminClientImpl])
  private val cacheFilename = "sparkSql_hive_jenkins.json"
  private val gson =
    new GsonBuilder().setDateFormat("yyyyMMdd-HH:mm:ss.SSS-Z").setPrettyPrinting().create

  override def init(serviceName: String, appId: String, configPropertyPrefix: String): Unit = {}

  override def getServicePoliciesIfUpdated(
      lastKnownVersion: Long,
      lastActivationTimeInMillis: Long): ServicePolicies = {
    val basedir = this.getClass.getProtectionDomain.getCodeSource.getLocation.getPath
    val cachePath = FileSystems.getDefault.getPath(basedir, cacheFilename)
    LOG.info("Reading policies from " + cachePath)
    val bytes = Files.readAllBytes(cachePath)
    gson.fromJson(new String(bytes), classOf[ServicePolicies])
  }

  override def grantAccess(request: GrantRevokeRequest): Unit = {}

  override def revokeAccess(request: GrantRevokeRequest): Unit = {}

  override def getServiceTagsIfUpdated(
      lastKnownVersion: Long,
      lastActivationTimeInMillis: Long): ServiceTags = null

  override def getTagTypes(tagTypePattern: String): util.List[String] = null
}
Example 20
Source File: RangerSparkPlugin.scala From spark-ranger with Apache License 2.0

package org.apache.ranger.authorization.spark.authorizer

import java.io.{File, IOException}

import org.apache.commons.logging.LogFactory
import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveAuthzSessionContext
import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveAuthzSessionContext.CLIENT_TYPE
import org.apache.ranger.authorization.hadoop.config.RangerConfiguration
import org.apache.ranger.plugin.service.RangerBasePlugin

class RangerSparkPlugin private extends RangerBasePlugin("spark", "sparkSql") {
  import RangerSparkPlugin._

  private val LOG = LogFactory.getLog(classOf[RangerSparkPlugin])

  lazy val fsScheme: Array[String] = RangerConfiguration.getInstance()
    .get("ranger.plugin.spark.urlauth.filesystem.schemes", "hdfs:,file:")
    .split(",")
    .map(_.trim)

  override def init(): Unit = {
    super.init()
    val cacheDir = new File(rangerConf.get("ranger.plugin.spark.policy.cache.dir"))
    if (cacheDir.exists() &&
      (!cacheDir.isDirectory || !cacheDir.canRead || !cacheDir.canWrite)) {
      throw new IOException("Policy cache directory already exists at" +
        cacheDir.getAbsolutePath + ", but it is unavailable")
    }

    if (!cacheDir.exists() && !cacheDir.mkdirs()) {
      throw new IOException("Unable to create ranger policy cache directory at" +
        cacheDir.getAbsolutePath)
    }
    LOG.info("Policy cache directory successfully set to " + cacheDir.getAbsolutePath)
  }
}

object RangerSparkPlugin {

  private val rangerConf: RangerConfiguration = RangerConfiguration.getInstance

  val showColumnsOption: String = rangerConf.get(
    "xasecure.spark.describetable.showcolumns.authorization.option", "NONE")

  def build(): Builder = new Builder

  class Builder {

    @volatile private var sparkPlugin: RangerSparkPlugin = _

    def getOrCreate(): RangerSparkPlugin = RangerSparkPlugin.synchronized {
      if (sparkPlugin == null) {
        sparkPlugin = new RangerSparkPlugin
        sparkPlugin.init()
        sparkPlugin
      } else {
        sparkPlugin
      }
    }
  }
}
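Because the class constructor is private, the companion's Builder is the entry point. A minimal sketch of obtaining the singleton plugin (the holder object is illustrative):

import org.apache.ranger.authorization.spark.authorizer.RangerSparkPlugin

object RangerSparkPluginHolder {
  // Lazily creates the plugin on first call; init() validates the policy cache directory.
  val plugin: RangerSparkPlugin = RangerSparkPlugin.build().getOrCreate()
}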
Example 21
Source File: ParentTest.scala From Soteria with MIT License

package com.leobenkel.soteria

import org.apache.commons.logging.{Log, LogFactory}
import org.scalactic.source.Position
import org.scalatest.{FunSuite, Tag}

trait ParentTest extends FunSuite {
  lazy val log: Log = LogFactory.getLog(this.getClass)

  protected def assertEquals[T](
    expected: T,
    result:   T
  )(
    implicit pos: Position
  ): Unit = {
    assertResult(expected)(result)
    ()
  }

  override protected def test(
    testName: String,
    testTags: Tag*
  )(
    testFun: => Any
  )(
    implicit pos: Position
  ): Unit = {
    super.test(testName, testTags: _*) {
      log.debug(s">>> Starting - $testName")
      testFun
    }
  }

  def time[R](block: => R): (R, Long) = {
    val t0 = System.nanoTime()
    val result = block
    val t1 = System.nanoTime()
    val time_ns: Long = t1 - t0
    (result, time_ns)
  }
}
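A small hedged sketch of a suite built on this trait; the test class and assertions are illustrative only.

package com.leobenkel.soteria

// Illustrative suite; every test run is logged by ParentTest's overridden test().
class ArithmeticTest extends ParentTest {
  test("addition is commutative") {
    assertEquals(2 + 3, 3 + 2)
  }

  test("time returns the result and the elapsed nanoseconds") {
    val (sum, elapsedNs) = time((1 to 100).sum)
    assertEquals(5050, sum)
    assert(elapsedNs >= 0L)
  }
}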
Example 22
Source File: ShapeInputFormat.scala From magellan with Apache License 2.0

package magellan.mapreduce

import com.google.common.base.Stopwatch
import magellan.io.{ShapeKey, ShapeWritable}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{LocatedFileStatus, Path}
import org.apache.hadoop.mapreduce.lib.input._
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, TaskAttemptContext}

import scala.collection.JavaConversions._
import scala.collection.mutable.ListBuffer

private[magellan] class ShapeInputFormat
  extends FileInputFormat[ShapeKey, ShapeWritable] {

  private val log = LogFactory.getLog(classOf[ShapeInputFormat])

  override def createRecordReader(inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext) = {
    new ShapefileReader
  }

  override def isSplitable(context: JobContext, filename: Path): Boolean = true

  override def getSplits(job: JobContext): java.util.List[InputSplit] = {
    val splitInfos = SplitInfos.SPLIT_INFO_MAP.get()
    computeSplits(job, splitInfos)
  }

  private def computeSplits(
      job: JobContext,
      splitInfos: scala.collection.Map[String, Array[Long]]) = {

    val sw = new Stopwatch().start
    val splits = ListBuffer[InputSplit]()
    val files = listStatus(job)
    for (file <- files) {
      val path = file.getPath
      val length = file.getLen
      val blkLocations = if (file.isInstanceOf[LocatedFileStatus]) {
        file.asInstanceOf[LocatedFileStatus].getBlockLocations
      } else {
        val fs = path.getFileSystem(job.getConfiguration)
        fs.getFileBlockLocations(file, 0, length)
      }
      val key = path.getName.split("\\.shp$")(0)
      if (splitInfos == null || !splitInfos.containsKey(key)) {
        val blkIndex = getBlockIndex(blkLocations, 0)
        splits.+=(makeSplit(path, 0, length, blkLocations(blkIndex).getHosts,
          blkLocations(blkIndex).getCachedHosts))
      } else {
        val s = splitInfos(key).toSeq
        val start = s
        val end = s.drop(1) ++ Seq(length)
        start.zip(end).foreach { case (startOffset: Long, endOffset: Long) =>
          val blkIndex = getBlockIndex(blkLocations, startOffset)
          splits.+=(makeSplit(path, startOffset, endOffset - startOffset,
            blkLocations(blkIndex).getHosts, blkLocations(blkIndex).getCachedHosts))
        }
      }
    }
    sw.stop
    if (log.isDebugEnabled) {
      log.debug("Total # of splits generated by getSplits: " + splits.size
        + ", TimeTaken: " + sw.elapsedMillis)
    }
    splits
  }
}

object SplitInfos {
  // TODO: Can we get rid of this hack to pass split calculation to the Shapefile Reader?
  val SPLIT_INFO_MAP = new ThreadLocal[scala.collection.Map[String, Array[Long]]]
}
Example 23
Source File: ExcelRelation.scala From spark-hadoopoffice-ds with Apache License 2.0

package org.zuinnote.spark.office.excel

import scala.collection.JavaConversions._

import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql._
import org.apache.spark.rdd.RDD
import org.apache.hadoop.conf._
import org.apache.hadoop.mapreduce._
import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.Log
import org.zuinnote.hadoop.office.format.common.dao._
import org.zuinnote.hadoop.office.format.mapreduce._
import org.zuinnote.spark.office.excel.util.ExcelFile

// NOTE: the enclosing class declaration (ExcelRelation, presumably a BaseRelation
// with TableScan, providing sqlContext, schema, location and hadoopParams) is
// omitted from this excerpt.
  override def buildScan: RDD[Row] = {
    // read ExcelRows
    val excelRowsRDD = ExcelFile.load(sqlContext, location, hadoopParams)
    // map to schema
    val schemaFields = schema.fields
    excelRowsRDD.flatMap(excelKeyValueTuple => {
      // map the Excel row data structure to a Spark SQL schema
      val rowArray = new Array[Any](excelKeyValueTuple._2.get.length)
      var i = 0;
      for (x <- excelKeyValueTuple._2.get) { // parse through the SpreadSheetCellDAO
        val spreadSheetCellDAOStructArray = new Array[String](schemaFields.length)
        val currentSpreadSheetCellDAO: Array[SpreadSheetCellDAO] =
          excelKeyValueTuple._2.get.asInstanceOf[Array[SpreadSheetCellDAO]]
        spreadSheetCellDAOStructArray(0) = currentSpreadSheetCellDAO(i).getFormattedValue
        spreadSheetCellDAOStructArray(1) = currentSpreadSheetCellDAO(i).getComment
        spreadSheetCellDAOStructArray(2) = currentSpreadSheetCellDAO(i).getFormula
        spreadSheetCellDAOStructArray(3) = currentSpreadSheetCellDAO(i).getAddress
        spreadSheetCellDAOStructArray(4) = currentSpreadSheetCellDAO(i).getSheetName
        // add row representing one Excel row
        rowArray(i) = spreadSheetCellDAOStructArray
        i += 1
      }
      Some(Row.fromSeq(rowArray))
    })
  }
}
Example 24
Source File: HadoopFileExcelReader.scala From spark-hadoopoffice-ds with Apache License 2.0

package org.zuinnote.spark.office.excel

import java.io.Closeable
import java.net.URI

import org.apache.spark.sql.execution.datasources._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileSplit, LineRecordReader}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark.sql.execution.datasources.RecordReaderIterator
import org.zuinnote.hadoop.office.format.mapreduce.ExcelFileInputFormat
import org.zuinnote.hadoop.office.format.mapreduce.ExcelRecordReader
import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.Log

class HadoopFileExcelReader(
    file: PartitionedFile, conf: Configuration) extends Iterator[ArrayWritable] with Closeable {

  val LOG = LogFactory.getLog(classOf[HadoopFileExcelReader])

  private var reader: RecordReader[Text, ArrayWritable] = null

  private val iterator = {
    val fileSplit = new FileSplit(
      new Path(new URI(file.filePath)),
      file.start,
      file.length,
      Array.empty) // todo: implement locality (replace Array.empty with the locations)
    val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
    val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
    val inputFormat = new ExcelFileInputFormat()
    reader = inputFormat.createRecordReader(fileSplit, hadoopAttemptContext)
    reader.initialize(fileSplit, hadoopAttemptContext)
    new RecordReaderIterator(reader)
  }

  def getReader: RecordReader[Text, ArrayWritable] = reader

  override def hasNext: Boolean = iterator.hasNext

  override def next(): ArrayWritable = iterator.next()

  override def close(): Unit = {
    if (reader != null) {
      reader.close()
    }
  }
}
Example 25
Source File: ExcelOutputWriter.scala From spark-hadoopoffice-ds with Apache License 2.0

package org.zuinnote.spark.office.excel

import java.math.BigDecimal
import java.sql.Date
import java.sql.Timestamp
import java.text.DateFormat
import java.text.SimpleDateFormat
import java.text.DecimalFormat
import java.text.NumberFormat
import java.util.Calendar
import java.util.Locale

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.mapreduce.RecordWriter
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._
import org.zuinnote.hadoop.office.format.common.dao.SpreadSheetCellDAO
import org.zuinnote.hadoop.office.format.common.HadoopOfficeWriteConfiguration
import org.zuinnote.hadoop.office.format.common.util.msexcel.MSExcelUtil
import org.zuinnote.hadoop.office.format.common.converter.ExcelConverterSimpleSpreadSheetCellDAO
import org.zuinnote.hadoop.office.format.mapreduce._
import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.Log

// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
// The initialization of useHeader, recordWriter, simpleConverter, defaultSheetName and
// currentRowNum is omitted from this excerpt.
private[excel] class ExcelOutputWriter(
    path: String,
    dataSchema: StructType,
    context: TaskAttemptContext,
    options: Map[String, String]) extends OutputWriter {

  def write(row: Row): Unit = {
    // check useHeader
    if (useHeader) {
      val headers = row.schema.fieldNames
      var i = 0
      for (x <- headers) {
        val headerColumnSCD =
          new SpreadSheetCellDAO(x, "", "", MSExcelUtil.getCellAddressA1Format(currentRowNum, i), defaultSheetName)
        recordWriter.write(NullWritable.get(), headerColumnSCD)
        i += 1
      }
      currentRowNum += 1
      useHeader = false
    }
    // for each value in the row
    if (row.size > 0) {
      var currentColumnNum = 0;
      val simpleObject = new Array[AnyRef](row.size)
      for (i <- 0 to row.size - 1) { // for each element of the row
        val obj = row.get(i)
        if ((obj.isInstanceOf[Seq[String]]) && (obj.asInstanceOf[Seq[String]].length == 5)) {
          val formattedValue = obj.asInstanceOf[Seq[String]](0)
          val comment = obj.asInstanceOf[Seq[String]](1)
          val formula = obj.asInstanceOf[Seq[String]](2)
          val address = obj.asInstanceOf[Seq[String]](3)
          val sheetName = obj.asInstanceOf[Seq[String]](4)
          simpleObject(i) = new SpreadSheetCellDAO(formattedValue, comment, formula, address, sheetName)
        } else {
          simpleObject(i) = obj.asInstanceOf[AnyRef]
        }
      }
      // convert row to spreadsheetcellDAO
      val spreadSheetCellDAORow =
        simpleConverter.getSpreadSheetCellDAOfromSimpleDataType(simpleObject, defaultSheetName, currentRowNum)
      // write it
      for (x <- spreadSheetCellDAORow) {
        recordWriter.write(NullWritable.get(), x)
      }
    }
    currentRowNum += 1
  }

  override def close(): Unit = {
    recordWriter.close(context)
    currentRowNum = 0;
  }
}
Example 26
Source File: SapThriftServer.scala From HANAVora-Extensions with Apache License 2.0

package org.apache.spark.sql.hive.thriftserver

import org.apache.commons.logging.LogFactory
import org.apache.spark.Logging
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.sap.thriftserver.SapSQLEnv
import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2._
import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab
import org.apache.hive.service.server.HiveServerServerOptionsProcessor

object SapThriftServer extends Logging {
  var LOG = LogFactory.getLog(classOf[SapThriftServer])

  def main(args: Array[String]) {
    val optionsProcessor = new HiveServerServerOptionsProcessor("SapThriftServer")
    if (!optionsProcessor.process(args)) {
      System.exit(-1)
    }

    logInfo("Starting SparkContext")
    SapSQLEnv.init()

    org.apache.spark.util.ShutdownHookManager.addShutdownHook { () =>
      SparkSQLEnv.stop()
      uiTab.foreach(_.detach())
    }

    try {
      val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
      server.init(SparkSQLEnv.hiveContext.hiveconf)
      server.start()
      logInfo("SapThriftServer started")
      listener = new HiveThriftServer2Listener(server, SparkSQLEnv.hiveContext.conf)
      SparkSQLEnv.sparkContext.addSparkListener(listener)
      uiTab = if (SparkSQLEnv.sparkContext.getConf.getBoolean("spark.ui.enabled", true)) {
        Some(new ThriftServerTab(SparkSQLEnv.sparkContext))
      } else {
        None
      }
    } catch {
      case e: Exception =>
        logError("Error starting SapThriftServer", e)
        System.exit(-1)
    }
  }
}

private[hive] class SapThriftServer(val hiveContext: HiveContext) extends Logging {

  def start: Unit = {
    logInfo("ThriftServer with SapSQLContext")
    logInfo("Starting SparkContext")
    HiveThriftServer2.startWithContext(hiveContext)
  }
}
Example 27
Source File: RandomSearch.scala From automl with Apache License 2.0

package com.tencent.angel.spark.automl.tuner.acquisition.optimizer

import com.tencent.angel.spark.automl.tuner.TunerParam
import com.tencent.angel.spark.automl.tuner.acquisition.Acquisition
import com.tencent.angel.spark.automl.tuner.config.{Configuration, ConfigurationSpace}
import org.apache.commons.logging.{Log, LogFactory}

import scala.util.Random

class RandomSearch(
    override val acqFunc: Acquisition,
    override val configSpace: ConfigurationSpace,
    seed: Int = 100) extends AcqOptimizer(acqFunc, configSpace) {

  val LOG: Log = LogFactory.getLog(classOf[RandomSearch])

  val rd = new Random(seed)

  override def maximize(numPoints: Int, sorted: Boolean = true): Array[(Double, Configuration)] = {
    //println(s"maximize RandomSearch")
    val configs: Array[Configuration] = configSpace.sample(TunerParam.sampleSize)
    if (configs.isEmpty) {
      Array[(Double, Configuration)]()
    } else {
      //configs.foreach { config =>
      //  println(s"sample a configuration: ${config.getVector.toArray.mkString(",")}")
      //}
      val retConfigs = if (sorted) {
        configs.map { config =>
          (acqFunc.compute(config.getVector)._1, config)
        }.sortWith(_._1 > _._1).take(numPoints)
      } else {
        rd.shuffle(configs.map { config =>
          (acqFunc.compute(config.getVector)._1, config)
        }.toTraversable).take(numPoints).toArray
      }
      retConfigs
    }
  }

  override def maximize: (Double, Configuration) = {
    maximize(1, true).head
  }
}
Example 28
package com.tencent.angel.spark.automl.tuner.acquisition

import com.tencent.angel.spark.automl.tuner.surrogate.Surrogate
import org.apache.commons.logging.{Log, LogFactory}
import org.apache.commons.math3.distribution.NormalDistribution
import org.apache.spark.ml.linalg.{Vector, Vectors}

class EI(
    override val surrogate: Surrogate,
    val par: Double) extends Acquisition(surrogate) {

  val LOG: Log = LogFactory.getLog(classOf[Surrogate])

  override def compute(X: Vector, derivative: Boolean = false): (Double, Vector) = {
    val pred = surrogate.predict(X) // (mean, variance)

    // Use the best seen observation as incumbent
    val eta: Double = surrogate.curBest._2
    //println(s"best seen result: $eta")

    val m: Double = pred._1
    val s: Double = Math.sqrt(pred._2)
    //println(s"${X.toArray.mkString("(", ",", ")")}: mean[$m], variance[$s]")

    if (s == 0) {
      // if std is zero, we have observed x on all instances
      // using a RF, std should be never exactly 0.0
      (0.0, Vectors.dense(new Array[Double](X.size)))
    } else {
      val z = (pred._1 - eta - par) / s
      val norm: NormalDistribution = new NormalDistribution
      val cdf: Double = norm.cumulativeProbability(z)
      val pdf: Double = norm.density(z)
      val ei = s * (z * cdf + pdf)
      //println(s"EI of ${X.toArray.mkString("(", ",", ")")}: $ei, cur best: $eta, z: $z, cdf: $cdf, pdf: $pdf")
      (ei, Vectors.dense(new Array[Double](X.size)))
    }
  }
}
Example 29
package com.tencent.angel.spark.automl.tuner.acquisition

import com.tencent.angel.spark.automl.tuner.surrogate.Surrogate
import org.apache.commons.logging.{Log, LogFactory}
import org.apache.spark.ml.linalg.{Vector, Vectors}

class UCB(
    override val surrogate: Surrogate,
    val beta: Double = 100) extends Acquisition(surrogate) {

  val LOG: Log = LogFactory.getLog(classOf[Surrogate])

  override def compute(X: Vector, derivative: Boolean = false): (Double, Vector) = {
    val pred = surrogate.predict(X) // (mean, variance)

    val m: Double = pred._1
    val s: Double = Math.sqrt(pred._2)

    if (s == 0) {
      // if std is zero, we have observed x on all instances
      // using a RF, std should be never exactly 0.0
      (0.0, Vectors.dense(new Array[Double](X.size)))
    } else {
      val ucb = m + beta * s
      (ucb, Vectors.dense(new Array[Double](X.size)))
    }
  }
}
Example 30
Source File: Surrogate.scala From automl with Apache License 2.0

package com.tencent.angel.spark.automl.tuner.surrogate

import com.tencent.angel.spark.automl.tuner.config.ConfigurationSpace
import org.apache.commons.logging.{Log, LogFactory}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

import scala.collection.mutable.ArrayBuffer

// NOTE: the enclosing class declaration of the Surrogate base class and the members
// it references (preX, preY, minimize) are omitted from this excerpt.

  def predict(X: Vector): (Double, Double)

  def stop(): Unit

  def curBest: (Vector, Double) = {
    if (minimize) curMin else curMax
  }

  def curMin: (Vector, Double) = {
    if (preY.isEmpty) (null, Double.MaxValue)
    else {
      val maxIdx: Int = preY.zipWithIndex.max._2
      (preX(maxIdx), -preY(maxIdx))
    }
  }

  def curMax: (Vector, Double) = {
    if (preY.isEmpty) (null, Double.MinValue)
    else {
      val maxIdx: Int = preY.zipWithIndex.max._2
      (preX(maxIdx), preY(maxIdx))
    }
  }
}