org.reflections.Reflections Scala Examples

The following examples show how to use org.reflections.Reflections. Each example is preceded by the name of its source file, the project it comes from, and that project's license.
Example 1
Source file: CatalogScanner.scala, from the seahorse project (Apache License 2.0); 5 votes
package ai.deepsense.deeplang.refl

import java.io.File
import java.net.{URL, URLClassLoader}

import ai.deepsense.commons.utils.Logging
import ai.deepsense.deeplang.catalogs.SortPriority
import ai.deepsense.deeplang.catalogs.spi.{CatalogRegistrant, CatalogRegistrar}
import ai.deepsense.deeplang.{DOperation, DOperationCategories, TypeUtils}
import org.reflections.Reflections
import org.reflections.util.ConfigurationBuilder

import scala.collection.JavaConversions._


  /**
   * Scans the configured jars for `@Register`-annotated classes and registers each
   * [[DOperation]] among them with `registrar`.
   *
   * Every discovered class is paired with the next priority from `SortPriority.sdkInSequence`
   * (pairing stops when either sequence runs out). Annotated classes that are not
   * DOperations are logged with a warning and skipped.
   */
  override def register(registrar: CatalogRegistrar): Unit = {
    logger.info(
      s"Scanning registrables. Following jars will be scanned: ${jarsUrls.mkString(";")}.")
    val registrablesWithPriority =
      scanForRegistrables().iterator.zip(SortPriority.sdkInSequence)
    registrablesWithPriority.foreach { case (registrable, priority) =>
      logger.debug(s"Trying to register class $registrable")
      registrable match {
        case DOperationMatcher(doperation) => registerDOperation(registrar, doperation, priority)
        case other => logger.warn(s"Only DOperation can be `@Register`ed. '$other' not supported.")
      }
    }
  }

  /**
   * Finds every class annotated with [[Register]] in this class's own jar (if it was loaded
   * from one) and in `jarsUrls`.
   *
   * @return the annotated classes, or an empty set when there is nothing to scan
   */
  private def scanForRegistrables(): Set[Class[_]] = {
    val urls = thisJarURLOpt ++ jarsUrls
    if (urls.isEmpty) {
      Set()
    } else {
      val configBuilder = ConfigurationBuilder
        .build(urls.toSeq: _*)
        .addClassLoader(getClass.getClassLoader)
        .setExpandSuperTypes(false)
      if (jarsUrls.nonEmpty) {
        // The extra jars get a dedicated loader whose parent is this class's own loader.
        val jarsClassLoader = URLClassLoader.newInstance(jarsUrls.toArray, getClass.getClassLoader)
        configBuilder.addClassLoader(jarsClassLoader)
      }
      new Reflections(configBuilder).getTypesAnnotatedWith(classOf[Register]).toSet
    }
  }

  /**
   * URL of the jar this class was loaded from.
   *
   * `None` when the class was not loaded from a `file:` jar (e.g. from a classes directory),
   * or when its class file cannot be located as a classloader resource.
   */
  private lazy val thisJarURLOpt: Option[URL] = {
    val jarRegex = """jar:(file:.*\.jar)!.*""".r

    // Classloader resource names are always '/'-separated, independent of the platform's
    // File.separator. The binary name (getName) keeps '$' for nested classes, so it maps onto
    // the real .class file; it is also never null, unlike getCanonicalName for anonymous classes.
    val classResource = getClass.getName.replace('.', '/') + ".class"

    // getResource returns null when the resource cannot be found.
    Option(getClass.getClassLoader.getResource(classResource))
      .map(_.toString)
      .collect { case jarRegex(jar) => new URL(jar) }
  }

  /**
   * Registers `operation` in the user-defined category of `registrar` with the given priority.
   *
   * The registered factory creates a fresh instance via the constructor returned by
   * `TypeUtils.constructorForClass`. When that lookup yields nothing, an error is logged
   * and nothing is registered.
   */
  private def registerDOperation(
    registrar: CatalogRegistrar,
    operation: Class[DOperation],
    priority: SortPriority
  ): Unit = TypeUtils.constructorForClass(operation) match {
    case Some(constructor) =>
      registrar.registerOperation(
        DOperationCategories.UserDefined,
        () => TypeUtils.createInstance[DOperation](constructor),
        priority
      )
    case None => logger.error(
      // Trailing space keeps the two sentences from running together in the log line.
      s"Class $operation could not be registered. " +
        "It needs to have a parameterless constructor."
    )
  }

  /**
   * Extractor that matches any class assignable to `targetClass` (the class itself or a
   * subtype) and yields it re-typed as `Class[T]`.
   */
  class AssignableFromExtractor[T](targetClass: Class[T]) {
    def unapply(clazz: Class[_]): Option[Class[T]] =
      Some(clazz)
        .filter(candidate => targetClass.isAssignableFrom(candidate))
        .map(_.asInstanceOf[Class[T]])
  }

  /** Pattern-matches classes that are [[DOperation]] or one of its subtypes. */
  object DOperationMatcher extends AssignableFromExtractor[DOperation](classOf[DOperation])

} 
Example 2
Source file: CatalogScanner.scala, from the seahorse-workflow-executor project (Apache License 2.0); 5 votes
package io.deepsense.deeplang.refl

import java.io.File
import java.net.{URL, URLClassLoader}

import scala.collection.JavaConversions._

import org.reflections.Reflections
import org.reflections.util.ConfigurationBuilder

import io.deepsense.commons.utils.Logging
import io.deepsense.deeplang.catalogs.doperable.DOperableCatalog
import io.deepsense.deeplang.catalogs.doperations.DOperationsCatalog
import io.deepsense.deeplang.{DOperation, DOperationCategories, TypeUtils}


  /**
   * Scans for `@Register`-annotated classes and registers every [[DOperation]] among them in
   * `dOperationsCatalog`. Other annotated classes are logged with a warning and skipped.
   *
   * @param dOperableCatalog   not used by this method
   * @param dOperationsCatalog catalog that receives the discovered operations
   */
  def scanAndRegister(
      dOperableCatalog: DOperableCatalog,
      dOperationsCatalog: DOperationsCatalog
  ): Unit = {
    logger.info(
      s"Scanning registrables. Following jars will be scanned: ${jarsUrls.mkString(";")}.")
    for (registrable <- scanForRegistrables()) {
      logger.debug(s"Trying to register class $registrable")
      registrable match {
        case DOperationMatcher(doperation) => registerDOperation(dOperationsCatalog, doperation)
        // Name the offending class so a misplaced annotation can be tracked down.
        case other => logger.warn(s"Only DOperation can be `@Register`ed. '$other' not supported.")
      }
    }
  }

  /**
   * Finds every class annotated with [[Register]] in this class's own jar (if it was loaded
   * from one) and in `jarsUrls`.
   *
   * @return the annotated classes, or an empty set when there is nothing to scan
   */
  private def scanForRegistrables(): Set[Class[_]] = {
    val urls = thisJarURLOpt ++ jarsUrls

    if (urls.nonEmpty) {
      val configBuilder = ConfigurationBuilder.build(urls.toSeq: _*)

      if (jarsUrls.nonEmpty) {
        // Parent the jar loader on this class's loader rather than the system loader, so that
        // classes in the extra jars resolve `Register`/`DOperation` through the same loader as
        // this code when the application is not running on the system class loader.
        configBuilder.addClassLoader(
          URLClassLoader.newInstance(jarsUrls.toArray, getClass.getClassLoader))
      }

      new Reflections(configBuilder).getTypesAnnotatedWith(classOf[Register]).toSet
    } else {
      Set()
    }
  }

  /**
   * URL of the jar this class was loaded from.
   *
   * `None` when the class was not loaded from a `file:` jar (e.g. from a classes directory),
   * or when its class file cannot be located as a classloader resource.
   */
  private lazy val thisJarURLOpt: Option[URL] = {
    val jarRegex = """jar:(file:.*\.jar)!.*""".r

    // Classloader resource names are always '/'-separated, independent of the platform's
    // File.separator. The binary name (getName) keeps '$' for nested classes, so it maps onto
    // the real .class file; it is also never null, unlike getCanonicalName for anonymous classes.
    val classResource = getClass.getName.replace('.', '/') + ".class"

    // getResource returns null when the resource cannot be found.
    Option(getClass.getClassLoader.getResource(classResource))
      .map(_.toString)
      .collect { case jarRegex(jar) => new URL(jar) }
  }


  /**
   * Registers `operation` in the user-defined category of `catalog`.
   *
   * The registered factory creates a fresh instance via the constructor returned by
   * `TypeUtils.constructorForClass`. When that lookup yields nothing, an error is logged
   * and nothing is registered.
   */
  private def registerDOperation(
      catalog: DOperationsCatalog,
      operation: Class[DOperation]
  ): Unit = TypeUtils.constructorForClass(operation) match {
    case Some(constructor) =>
      catalog.registerDOperation(
        DOperationCategories.UserDefined,
        () => TypeUtils.createInstance[DOperation](constructor)
      )
    case None => logger.error(
      // Trailing space keeps the two sentences from running together in the log line.
      s"Class $operation could not be registered. " +
        "It needs to have a parameterless constructor."
    )
  }

  /**
   * Extractor matching classes that can be assigned to `targetClass`; a match yields the
   * class cast to `Class[T]`.
   */
  class AssignableFromExtractor[T](targetClass: Class[T]) {
    def unapply(clazz: Class[_]): Option[Class[T]] =
      if (!targetClass.isAssignableFrom(clazz)) None
      else Some(clazz.asInstanceOf[Class[T]])
  }

  /** Matches [[DOperation]] and its subtypes, yielding them as `Class[DOperation]`. */
  object DOperationMatcher extends AssignableFromExtractor[DOperation](classOf[DOperation])

} 
Example 3
Source file: ProvidersFactory.scala, from the amaterasu project (Apache License 2.0); 5 votes
package org.apache.amaterasu.executor.mesos.executors

import java.io.ByteArrayOutputStream

import org.apache.amaterasu.common.dataobjects.ExecData
import org.apache.amaterasu.common.execution.actions.Notifier
import org.apache.amaterasu.sdk.{AmaterasuRunner, RunnersProvider}
import org.reflections.Reflections

import scala.collection.JavaConversions._

//TODO: Check if we can use this in the YARN impl
class ProvidersFactory {

  /**
   * Runner providers keyed by group identifier; assigned by `ProvidersFactory.apply`.
   * Starts empty (rather than null) so lookups before assignment simply find nothing.
   */
  var providers: Map[String, RunnersProvider] = Map.empty

  /**
   * Asks the provider registered for `groupId` for the runner identified by `id`.
   *
   * @return `None` when no provider is registered for `groupId`
   */
  def getRunner(groupId: String, id: String): Option[AmaterasuRunner] =
    providers.get(groupId).map(_.getRunner(id))
}

object ProvidersFactory {

  /**
   * Builds a [[ProvidersFactory]] from every concrete [[RunnersProvider]] subtype found by
   * scanning this object's class loader.
   *
   * Each provider is instantiated through its no-argument constructor and initialised with
   * the given execution context. The providers are keyed by `getGroupIdentifier`; if two
   * share an identifier, only one of them is kept.
   */
  def apply(data: ExecData,
            jobId: String,
            outStream: ByteArrayOutputStream,
            notifier: Notifier,
            executorId: String): ProvidersFactory = {

    val result = new ProvidersFactory()
    val reflections = new Reflections(getClass.getClassLoader)
    // Interfaces and abstract classes cannot be instantiated; skip them instead of failing.
    val runnerTypes = reflections.getSubTypesOf(classOf[RunnersProvider]).toSet
      .filterNot(r => r.isInterface || java.lang.reflect.Modifier.isAbstract(r.getModifiers))

    result.providers = runnerTypes.map { r =>
      // Constructor#newInstance replaces the deprecated Class#newInstance.
      val provider: RunnersProvider = r.getDeclaredConstructor().newInstance()

      notifier.info(s"a provider for group ${provider.getGroupIdentifier} was created")
      provider.init(data, jobId, outStream, notifier, executorId)
      (provider.getGroupIdentifier, provider)
    }.toMap

    result
  }

}
Example 4
Source file: ModelSerializabilityTestBase.scala, from the aloha project (MIT License); 5 votes
package com.eharmony.aloha

import scala.language.existentials
import com.eharmony.aloha
import com.eharmony.aloha.models.{Model, SubmodelBase}
import org.junit.Assert._
import org.junit.Test
import org.reflections.Reflections

import scala.collection.JavaConversions.asScalaSet
import scala.util.Try
import java.lang.reflect.{Method, Modifier}

import com.eharmony.aloha.util.Logging


/**
 * Base for JUnit tests requiring that every non-interface [[Model]] / [[SubmodelBase]] subtype
 * found in `pkgs` has a companion test class named `<canonical model name>Test`. That test
 * class must expose a public, `@Test`-annotated, void `testSerialization` method.
 *
 * @param pkgs       package prefixes scanned for model classes
 * @param outFilters regular expressions; model classes whose fully-qualified name matches any
 *                   of them are excluded from the check
 */
abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[String])
extends Logging {

  /** Scans the aloha root package with no exclusions. */
  def this() = this(pkgs = Seq(aloha.pkgName), Seq.empty)

  @Test def testSerialization(): Unit = {
    val ref = new Reflections(pkgs:_*)
    val submodels = ref.getSubTypesOf(classOf[SubmodelBase[_, _, _, _]]).toSeq
    val models = ref.getSubTypesOf(classOf[Model[_, _]]).toSeq

    val modelClasses =
      (models ++ submodels).
        filterNot { _.isInterface }.
        filterNot { c =>
          val name = c.getName
          outFilters.exists(name.matches)
        }

    if (modelClasses.isEmpty) {
      fail(s"No models found to test for Serializability in packages: ${pkgs.mkString(",")}")
    }
    else {
      debug {
        modelClasses
          .map(_.getCanonicalName)
          .mkString("Models tested for Serializability:\n\t", "\n\t", "")
      }
    }

    modelClasses.foreach { c =>
      val m = for {
        testClass  <- getTestClass(c.getCanonicalName)
        testMethod <- getTestMethod(testClass)
        method     <- ensureTestMethodIsTest(testMethod)
      } yield method

      m.left foreach fail
    }
  }

  /**
   * Gives `Either` right-biased `map`/`flatMap` so the `for`-comprehension above short-circuits
   * on the first `Left` (needed on Scala versions where `Either` is not right-biased itself).
   */
  private[this] implicit class RightMonad[L, R](e: Either[L, R]) {
    def flatMap[R1](f: R => Either[L, R1]) = e.right.flatMap(f)
    def map[R1](f: R => R1) = e.right.map(f)
  }

  /** Loads `<modelClassName>Test`, or returns a `Left` message when it cannot be loaded. */
  private[this] def getTestClass(modelClassName: String) = {
    val testName = modelClassName + "Test"
    Try {
      Class.forName(testName)
    } map {
      Right(_)
    } getOrElse Left("No test class exists for " + modelClassName)
  }

  /** Looks up the public no-argument `testSerialization` method of `testClass`. */
  private[this] def getTestMethod(testClass: Class[_]) = {
    val testMethodName = "testSerialization"
    lazy val msg = s"$testMethodName doesn't exist in ${testClass.getCanonicalName}."
    Try {
      Option(testClass.getMethod(testMethodName))
    } map {
      case Some(m) => Right(m)
      case None => Left(msg)
    } getOrElse Left(msg)
  }

  /**
   * Checks that `method` is public, annotated with `@org.junit.Test` and returns void,
   * reporting the first violated requirement as a `Left`.
   */
  private[this] def ensureTestMethodIsTest(method: Method) = {
    val location = s"testSerialization in ${method.getDeclaringClass.getCanonicalName}"
    // Each requirement must be part of the same if/else chain: previously the "not public"
    // Left was a standalone statement whose value was discarded, so that check never fired.
    // (Class#getMethod only returns public methods, so this branch is a defensive guard.)
    if (!Modifier.isPublic(method.getModifiers))
      Left(s"$location is not public")
    else if (!method.getDeclaredAnnotations.exists(_.annotationType() == classOf[Test]))
      Left(s"$location does not have a @org.junit.Test annotation.")
    else if (method.getReturnType != classOf[Void] && method.getReturnType != classOf[Unit])
      Left(s"$location is not a void function. It returns: ${method.getReturnType}")
    else Right(method)
  }
}
Example 5
Source file: RowCreatorProducerTest.scala, from the aloha project (MIT License); 5 votes
package com.eharmony.aloha.dataset

import java.lang.reflect.Modifier

import com.eharmony.aloha
import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator
import org.junit.Assert._
import org.junit.{Ignore, Test}
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.collection.JavaConversions.asScalaSet
import org.reflections.Reflections

@RunWith(classOf[BlockJUnit4ClassRunner])
class RowCreatorProducerTest {
    import RowCreatorProducerTest._

    private[this] def scanPkg = aloha.pkgName + ".dataset"

    /** Every `RowCreatorProducer` subtype found under `scanPkg` by classpath scanning. */
    private[this] def rowCreatorProducerClasses =
        new Reflections(scanPkg).getSubTypesOf(classOf[RowCreatorProducer[_, _, _]]).toSet

    /**
     * Each producer may declare at most one public constructor, and, unless whitelisted,
     * that constructor must take no arguments.
     */
    @Test def testAllRowCreatorProducersHaveOnlyZeroArgConstructors(): Unit = {
        rowCreatorProducerClasses.foreach { clazz =>
            val cons = clazz.getConstructors
            assertTrue(s"There should only be one constructor for ${clazz.getCanonicalName}.  Found ${cons.length} constructors.", cons.length <= 1)
            cons.headOption.foreach { c =>
                if (!(WhitelistedRowCreatorProducers contains clazz)) {
                    val nParams = c.getParameterTypes.length
                    assertEquals(s"The constructor for ${clazz.getCanonicalName} should take 0 arguments.  It takes $nParams.", 0, nParams)
                }
            }
        }
    }

    /** Every producer class must be declared `final`. */
    // NOTE(review): the "above bug" referenced below is not described in this file; the
    // explanatory comment appears to be missing. Presumably it is why this test is @Ignore'd.
    // TODO: Report the above bug!
    @Ignore @Test def testAllRowCreatorProducersAreFinalClasses(): Unit = {
        rowCreatorProducerClasses.foreach { clazz =>
            assertTrue(s"${clazz.getCanonicalName} needs to be declared final.", Modifier.isFinal(clazz.getModifiers))
        }
    }
}

object RowCreatorProducerTest {
    /** Producer classes exempt from the zero-argument-constructor requirement. */
    private val WhitelistedRowCreatorProducers: Set[Class[_]] =
        Set(classOf[VwMultilabelRowCreator.Producer[_, _]])
}
Example 6
Source file: RuntimeClasspathScanning.scala, from the aloha project (MIT License); 5 votes
package com.eharmony.aloha.reflect

import com.eharmony.aloha
import org.reflections.Reflections

import scala.reflect.{classTag, ClassTag}
import scala.util.Try


  /**
   * Collects values of type `A` produced by calling `methodName` on objects found by scanning.
   *
   * Scans `packageToSearch` for subtypes of `OBJ` that `isObject` accepts. For each, it loads
   * the class named by dropping `objectSuffix` from the canonical name and invokes
   * `methodName` on it statically (`invoke(null)`). Results of type `A` are kept; any lookup,
   * invocation or type mismatch causes that object to be skipped.
   *
   * @param methodName      name of a zero-argument method to invoke
   * @param packageToSearch package prefix to scan (defaults to the aloha root package)
   */
  protected[this] def scanObjects[OBJ: ClassTag, A: ClassTag](
      methodName: String,
      packageToSearch: String = aloha.pkgName
  ): Seq[A] = {
    // Honor the caller's package; the default still scans all of aloha.
    val reflections = new Reflections(packageToSearch)
    import scala.collection.JavaConversions.asScalaSet
    val objects = reflections.getSubTypesOf(classTag[OBJ].runtimeClass).toSeq

    val suffixLength = objectSuffix.length

    objects.flatMap {
      case o if isObject(o) =>
        Try {
          // This may have some classloading issues.
          val classObj = Class.forName(o.getCanonicalName.dropRight(suffixLength))
          classObj.getMethod(methodName).invoke(null) match {
            case a: A => a
            case _ => throw new IllegalStateException()
          }
        }.toOption
      case _ => None
    }
  }
} 
Example 7
Source file: Cli.scala, from the aloha project (MIT License); 5 votes
package com.eharmony.aloha.cli

import java.lang.reflect.Modifier

import com.eharmony.aloha.annotate.CLI
import org.reflections.Reflections

import scala.collection.JavaConversions.asScalaSet
import com.eharmony.aloha.pkgName


object Cli {
    /**
     * Dispatches to the `@CLI`-annotated tool whose flag equals `args(0)`, passing it the
     * remaining arguments. When no flag or an unknown flag is supplied, the known flags are
     * printed to stderr instead.
     */
    def main(args: Array[String]): Unit = {
        if (args.isEmpty) {
            Console.err.println("No arguments supplied. Supply one of: " + availableFlags + ".")
        }
        else {
            val flag = args(0)
            flagClassMap.get(flag) match {
                case None =>
                    Console.err.println(s"'$flag' supplied. Supply one of: " + availableFlags + ".")
                case Some(cliClass) =>
                    cliClass.
                        getMethod("main", classOf[Array[String]]).
                        invoke(null, args.tail)
            }
        }
    }

    /** Known flags, sorted, single-quoted and comma-separated, for error messages. */
    private[this] def availableFlags: String =
        flagClassMap.keys.toVector.sorted.map("'" + _ + "'").mkString(", ")

    /** `@CLI`-annotated classes under the aloha package that can be launched via `main`. */
    private[cli] lazy val cliClasses = {
        val reflections = new Reflections(pkgName)

        // We want the classes with the static forwarders, not the singleton (module) classes.
        reflections.getTypesAnnotatedWith(classOf[CLI]).toSet.asInstanceOf[Set[Class[Any]]].filter(hasStaticMain)
    }

    /**
     * True for non-module classes (name not ending in '$') that have a public static
     * `main(Array[String])`. `getMethod` throws rather than returning null when the method is
     * absent, so the lookup is wrapped in `Try` to reject such classes instead of crashing.
     */
    private[this] def hasStaticMain(c: Class[Any]): Boolean =
        !c.getName.endsWith("$") &&
            scala.util.Try(c.getMethod("main", classOf[Array[String]])).toOption.exists(m => Modifier.isStatic(m.getModifiers))

    /** Launchable classes keyed by the flag declared in their `@CLI` annotation. */
    private[this] lazy val flagClassMap = cliClasses.map(c => c.getAnnotation(classOf[CLI]).flag() -> c).toMap
}
Example 8
Source file: ProjectGenerator.scala, from the TransmogrifAI project (BSD 3-Clause "New" or "Revised" License); 5 votes
package com.salesforce.op.cli.gen

import java.io.InputStream

import com.salesforce.op.cli.GeneratorConfig
import com.salesforce.op.cli.gen.FileSource.{Str, Streaming}
import com.salesforce.op.cli.gen.templates.SimpleProject
import org.reflections.Reflections
import org.reflections.scanners.ResourcesScanner
import org.scalafmt.Scalafmt
import org.scalafmt.config.ScalafmtConfig

import collection.JavaConverters._
import scala.io.Source

trait ProjectGenerator {

  /**
   * Whether the template resource at `path` is copied into the project verbatim instead of
   * being rendered as a template. By default every resource is rendered.
   */
  def shouldCopy(path: String): Boolean = false

  private lazy val templateResourcePath = s"templates/$name/"

  /** Paths of all resources found under `templates/<name>/`, relative to that directory. */
  lazy val templateResources: List[String] =
    new Reflections(s"templates.$name", new ResourcesScanner)
      .getResources(".*".r.pattern)
      .asScala.toList.map(_.replace(templateResourcePath, ""))

  /** Opens the classpath resource for `templatePath`; `getResourceAsStream` yields null if absent. */
  private def loadResource(templatePath: String): InputStream = {
    val resourcePath = s"/$templateResourcePath$templatePath"
    getClass.getResourceAsStream(resourcePath)
  }

  /** Reads a template resource fully, terminating every line with '\n', and closes it. */
  private def loadTemplateFileSource(templatePath: String): String = {
    val source = Source.fromInputStream(loadResource(templatePath))
    // mkString consumes all lines before the finally block releases the underlying stream.
    try source.getLines.map(_ + "\n").mkString
    finally source.close()
  }

  /** Template generators for every resource that is rendered rather than copied, by path. */
  private lazy val templates: Map[String, FileGenerator] = templateResources.filterNot(shouldCopy).map { path =>
    path -> new FileGenerator(path, loadTemplateFileSource(path))
  }.toMap

  /** Resources copied verbatim (streamed, not rendered) into the generated project. */
  private lazy val copies: List[FileInProject] = templateResources.filter(shouldCopy).map { path =>
    FileInProject(path = path, source = Streaming(loadResource(path)))
  }

  /** The generator for the rendered template at relative path `name`. */
  def templateFile(name: String): FileGenerator = templates(name)

  /**
   * Renders every non-copied template with `substitutions`, optionally formats the rendered
   * `.scala` files with scalafmt (120 columns), adds the verbatim copies, and finally applies
   * `replace` to every file.
   */
  def renderAll(
    substitutions: FileGenerator.Substitutions, formatScala: Boolean = true
  ): Traversable[FileInProject] = {

    val rendered = templates map {
      case (_, tpl) => tpl.render(substitutions)
    }

    val formatted =
      if (formatScala) rendered map {
        case f @ FileInProject(path, Str(source), _) if path.endsWith(".scala") =>
          f.copy(source = Str(Scalafmt.format(source, new ScalafmtConfig(maxColumn = 120)).get))
        case f => f
      }
      else rendered

    val allFiles = formatted ++ copies

    allFiles map replace
  }

}

object ProjectGenerator {
  /** Project generator factories, keyed by generator name. */
  val byName: Map[String, Ops => ProjectGenerator] = Map("simple" -> SimpleProject)
}
Example 9
Source file: ReportAndLogSupport.scala, from the RTran project (Apache License 2.0); 5 votes
package com.ebay.rtran.report

import java.io.{File, FileOutputStream}

import ch.qos.logback.classic
import ch.qos.logback.classic.encoder.PatternLayoutEncoder
import ch.qos.logback.classic.filter.ThresholdFilter
import ch.qos.logback.classic.spi.ILoggingEvent
import ch.qos.logback.core.FileAppender
import org.reflections.Reflections
import com.ebay.rtran.report.api.IReportEventSubscriber
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConversions._
import scala.language.postfixOps
import scala.util.{Success, Try}


trait ReportAndLogSupport {

  /** File-name prefix of the markdown report written into the project root. */
  val reportFilePrefix: String
  /** File-name prefix of the log file receiving WARN-threshold events. */
  val warnLogPrefix: String
  /** File-name prefix of the log file receiving DEBUG-threshold events. */
  val debugLogPrefix: String

  /**
   * Runs `fn` while collecting a markdown report and two log files in `projectRoot`.
   *
   * The files are named `<prefix>[-<taskId>].md` / `.log`. The report is produced by
   * `Report.createReport` with the subscribers discovered in `packages`. The log appenders
   * are attached to the root logger only for the duration of the call and are always
   * detached afterwards, even if `fn` throws.
   *
   * @return the result of `fn`
   */
  def createReportAndLogs[T](projectRoot: File,
                             taskId: Option[String], packages: String*)(fn: => T): T = {
    val appenders = prepareAppenders(projectRoot, taskId)
    val rootLogger = LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME).asInstanceOf[classic.Logger]
    appenders foreach rootLogger.addAppender
    appenders foreach (_.start)
    try {
      val reportFile = new File(projectRoot, s"$reportFilePrefix${taskId.map("-" + _) getOrElse ""}.md")
      // NOTE(review): confirm that Report.createReport closes this stream.
      Report.createReport(
        new FileOutputStream(reportFile),
        subscribers = allSubscribers(projectRoot, packages: _*)
      )(fn)
    } finally {
      // Release the log files and the root-logger registration on both success and failure.
      appenders foreach (_.stop)
      appenders foreach rootLogger.detachAppender
    }
  }

  /**
   * Instantiates every `IReportEventSubscriber` subtype found under the given package prefixes.
   *
   * A constructor taking the project root `File` is preferred, falling back to the no-argument
   * constructor; classes for which both fail are skipped. The result is sorted by `sequence`.
   */
  def allSubscribers(projectRoot: File, packages: String*) = {
    val subscribers = packages flatMap {prefix =>
      new Reflections(prefix).getSubTypesOf(classOf[IReportEventSubscriber[_]])
    } map {clazz =>
      Try(clazz.getDeclaredConstructor(classOf[File]).newInstance(projectRoot)) orElse Try(clazz.newInstance)
    } collect {
      case Success(subscriber) => subscriber
    } toList

    subscribers.sortBy(_.sequence)
  }

  /**
   * Creates (but does not start) two file appenders in `projectRoot`: one filtered at WARN
   * and one at DEBUG. Both use a `SameThreadFilter` and a timestamped pattern layout.
   */
  private def prepareAppenders(projectRoot: File, taskId: Option[String]) = {
    val lc = LoggerFactory.getILoggerFactory.asInstanceOf[classic.LoggerContext]
    val encoders = Array(new PatternLayoutEncoder, new PatternLayoutEncoder)
    encoders foreach (_ setContext lc)
    encoders foreach (_ setPattern "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}.%M:%L - %m%n")
    encoders foreach (_.start)

    val warnFileAppender = new FileAppender[ILoggingEvent]
    warnFileAppender.setName("warnFileAppender")
    warnFileAppender.setFile(s"${projectRoot.getAbsolutePath}/$warnLogPrefix${taskId.map("-" + _) getOrElse ""}.log")
    warnFileAppender.addFilter(new SameThreadFilter)
    val warnFilter = new ThresholdFilter
    warnFilter.setLevel("WARN")
    warnFilter.start()
    warnFileAppender.addFilter(warnFilter)

    val debugFileAppender = new FileAppender[ILoggingEvent]
    debugFileAppender.setName("debugFileAppender")
    debugFileAppender.setFile(s"${projectRoot.getAbsolutePath}/$debugLogPrefix${taskId.map("-" + _) getOrElse ""}.log")
    debugFileAppender.addFilter(new SameThreadFilter)
    val debugFilter = new ThresholdFilter
    debugFilter.setLevel("DEBUG")
    debugFilter.start()
    debugFileAppender.addFilter(debugFilter)

    val result = List(warnFileAppender, debugFileAppender)
    result.foreach(_ setContext lc)
    result zip encoders foreach (entry => entry._1 setEncoder entry._2)
    result
  }
}
Example 10
Source file: ReflectionUtils.scala, from the ohara project (Apache License 2.0); 5 votes
package oharastream.ohara.configurator

import java.lang.reflect.Modifier
import com.typesafe.scalalogging.Logger
import oharastream.ohara.client.configurator.FileInfoApi.ClassInfo
import oharastream.ohara.common.setting.WithDefinitions
import oharastream.ohara.kafka.connector.{RowSinkConnector, RowSourceConnector}
import org.reflections.Reflections
import org.reflections.util.{ClasspathHelper, ConfigurationBuilder}

import scala.jdk.CollectionConverters._
object ReflectionUtils {
  private[this] val LOG = Logger(ReflectionUtils.getClass)

  /**
   * One [[ClassInfo]] per non-abstract [[WithDefinitions]] subtype on the JVM classpath that is
   * also a [[RowSourceConnector]] or [[RowSinkConnector]].
   *
   * Each connector is instantiated through its no-argument constructor to read its setting
   * definitions. Connectors that fail to load or instantiate are logged and skipped.
   */
  lazy val localConnectorDefinitions: Seq[ClassInfo] =
    new Reflections(
      new ConfigurationBuilder()
      // we ought to define urls manually since Reflections does not work on java 11
      // It can't find correct urls without pre-defined urls.
        .setUrls(ClasspathHelper.forJavaClassPath)
    ).getSubTypesOf(classOf[WithDefinitions])
      .asScala
      .toSeq
      .filter(
        clz => classOf[RowSourceConnector].isAssignableFrom(clz) || classOf[RowSinkConnector].isAssignableFrom(clz)
      )
      // the abstract class is not instantiable.
      .filterNot(clz => Modifier.isAbstract(clz.getModifiers))
      .flatMap { clz =>
        try Some((clz.getName, clz.getDeclaredConstructor().newInstance().settingDefinitions().values().asScala.toSeq))
        catch {
          // Best effort: linkage problems (e.g. a missing dependency) and ordinary exceptions
          // skip the connector, while fatal VM errors and interruption still propagate.
          case e @ (_: LinkageError | scala.util.control.NonFatal(_)) =>
            LOG.error(s"failed to instantiate ${clz.getName}", e)
            None
        }
      }
      .map {
        case (className, definitions) =>
          ClassInfo(
            className = className,
            settingDefinitions = definitions
          )
      }
}
Example 11
Source file: Run.scala, from the codepropertygraph project (Apache License 2.0); 5 votes
package io.shiftleft.console

import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.passes.{CpgPass, DiffGraph}
import io.shiftleft.semanticcpg.language.HasStoreMethod
import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext}
import org.reflections.Reflections

import scala.jdk.CollectionConverters._

object Run {

  /**
   * Runs `query` against the console's current CPG as an ad-hoc layer named "custom".
   *
   * The layer consists of a single [[CpgPass]] that calls `query.store` with a fresh implicit
   * `DiffGraph.Builder` in scope. The resulting diff graph is then applied and serialized.
   */
  def runCustomQuery(console: Console[_], query: HasStoreMethod): Unit = {
    console._runAnalyzer(
      new LayerCreator {
        override val overlayName: String = "custom"
        override val description: String = "A custom pass"

        override def create(context: LayerCreatorContext, serializeInverse: Boolean): Unit = {
          val serializedCpg = initSerializedCpg(context.outputDir, "custom", 0)
          val pass = new CpgPass(console.cpg) {
            override def run(): Iterator[DiffGraph] = {
              implicit val diffGraph: DiffGraph.Builder = DiffGraph.newBuilder
              query.store
              Iterator(diffGraph.build())
            }
          }
          // NOTE(review): `serializeInverse` is ignored; inverse is always true here. Confirm intended.
          pass.createApplySerializeAndStore(serializedCpg, inverse = true, "custom")
          serializedCpg.close()
        }
        override def probe(cpg: Cpg): Boolean = false
      }
    )
  }

  /**
   * Generates Scala source that defines `opts` (each layer creator's `defaultOpts`) and `run`
   * (one method per layer creator, an `apply` for custom queries, and a tabular `toString`).
   *
   * Layer creators are the `LayerCreator` subtypes under `io.shiftleft` that are not anonymous,
   * local, member or synthetic classes. Classes of this object itself, and those whose
   * fully-qualified name is listed in `exclude`, are left out.
   */
  def codeForRunCommand(exclude: List[String] = List()): String = {
    val r = new Reflections("io.shiftleft")
    val layerCreatorTypeNames = r
      .getSubTypesOf(classOf[LayerCreator])
      .asScala
      .filterNot(t => t.isAnonymousClass || t.isLocalClass || t.isMemberClass || t.isSynthetic)
      .filterNot(t => t.getName.startsWith("io.shiftleft.console.Run"))
      .toList
      .map(t => (t.getSimpleName.toLowerCase, t.getName))
      .filter(t => !exclude.contains(t._2))

    val optsMembersCode = layerCreatorTypeNames
      .map { case (varName, typeName) => s"val $varName = $typeName.defaultOpts" }
      .mkString("\n")

    val optsCode =
      s"""
        |class OptsDynamic {
        | $optsMembersCode
        |}
        |
        |val opts = new OptsDynamic()
        |
        | import io.shiftleft.passes.DiffGraph
        | implicit def _diffGraph : DiffGraph.Builder = opts.commit.diffGraphBuilder
        | def diffGraph = _diffGraph
        |""".stripMargin

    val membersCode = layerCreatorTypeNames
      .map { case (varName, typeName) => s"def $varName: Cpg = _runAnalyzer(new $typeName(opts.$varName))" }
      .mkString("\n")

    // Emit the generated `List(...)` literal explicitly instead of depending on List#toString;
    // the produced text is the same ("List(" + rows joined by ", " + ")").
    val rowsCode = layerCreatorTypeNames
      .map { case (varName, typeName) => s"""List("$varName",$typeName.description.trim)""" }
      .mkString("List(", ", ", ")")

    val toStringCode =
      s"""
         | import overflowdb.traversal.help.Table
         | override def toString() : String = {
         |  val columnNames = List("name", "description")
         |  val rows =
         |   $rowsCode
         | "\\n" + Table(columnNames, rows).render
         | }
         |""".stripMargin

    optsCode +
      s"""
       | class OverlaysDynamic {
       |
       | def apply(query : io.shiftleft.semanticcpg.language.HasStoreMethod) {
       |   io.shiftleft.console.Run.runCustomQuery(console, query)
       | }
       |
       | $membersCode
       |
       | $toStringCode
       | }
       | val run = new OverlaysDynamic()
       |""".stripMargin
  }

}
Example 12
Source file: ReflectionUtils.scala, from the sparta project (Apache License 2.0); 5 votes
package com.stratio.sparta.serving.core.utils

import java.io.Serializable
import java.net.URLClassLoader

import akka.event.slf4j.SLF4JLogging
import com.stratio.sparta.sdk.pipeline.aggregation.cube.DimensionType
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import com.stratio.sparta.sdk.pipeline.input.Input
import com.stratio.sparta.sdk.pipeline.output.Output
import com.stratio.sparta.sdk.pipeline.transformation.Parser
import org.reflections.Reflections
import com.stratio.sparta.serving.core.exception.ServingCoreException

import scala.collection.JavaConversions._

class ReflectionUtils extends SLF4JLogging {

  /**
   * Loads the class `classAndPackage` and passes it to `block`.
   *
   * If `classAndPackage` is a plugin simple name known to [[getClasspathMap]], it is first
   * translated to that plugin's canonical name.
   *
   * @throws ServingCoreException wrapping any `Exception` raised while loading the class or
   *                              running `block`
   */
  def tryToInstantiate[C](classAndPackage: String, block: Class[_] => C): C = {
    val clazMap: Map[String, String] = getClasspathMap
    val finalClazzToInstance = clazMap.getOrElse(classAndPackage, classAndPackage)
    try {
      val clazz = Class.forName(finalClazzToInstance)
      block(clazz)
    } catch {
      case cnfe: ClassNotFoundException =>
        throw ServingCoreException.create(
          "Class with name " + classAndPackage + " Cannot be found in the classpath.", cnfe)
      case ie: InstantiationException =>
        throw ServingCoreException.create("Class with name " + classAndPackage + " cannot be instantiated", ie)
      case e: Exception =>
        throw ServingCoreException.create("Generic error trying to instantiate " + classAndPackage, e)
    }
  }

  /** Instantiates `clazz` through its constructor taking a single `Map[String, Serializable]`. */
  def instantiateParameterizable[C](clazz: Class[_], properties: Map[String, Serializable]): C =
    clazz.getDeclaredConstructor(classOf[Map[String, Serializable]]).newInstance(properties).asInstanceOf[C]

  /**
   * Debug-logs the URLs of `cl` when it is a [[URLClassLoader]]. Other loaders (e.g. the JDK 9+
   * application class loader) are reported instead of failing with a `ClassCastException`.
   */
  def printClassPath(cl: ClassLoader): Unit = cl match {
    case urlClassLoader: URLClassLoader =>
      urlClassLoader.getURLs.foreach(url => log.debug(url.getFile))
    case other =>
      log.debug(s"Cannot list classpath of $other: not a URLClassLoader")
  }

  /**
   * Plugin classes under `com.stratio.sparta`, mapping simple name to canonical name.
   *
   * Plugins are the subtypes of `Input`, `DimensionType`, `Operator`, `Output` and `Parser`.
   */
  lazy val getClasspathMap: Map[String, String] = {
    val reflections = new Reflections("com.stratio.sparta")

    // The class-loader dump is purely diagnostic, so skip it unless debug logging is enabled.
    if (log.isDebugEnabled) logClassLoaders()

    val inputs = reflections.getSubTypesOf(classOf[Input]).toList
    val dimensionTypes = reflections.getSubTypesOf(classOf[DimensionType]).toList
    val operators = reflections.getSubTypesOf(classOf[Operator]).toList
    val outputs = reflections.getSubTypesOf(classOf[Output]).toList
    val parsers = reflections.getSubTypesOf(classOf[Parser]).toList
    val plugins = inputs ++ dimensionTypes ++ operators ++ outputs ++ parsers
    // NOTE(review): plugins sharing a simple name collide here; the last one listed wins.
    val result = plugins.map(t => t.getSimpleName -> t.getCanonicalName).toMap

    log.debug("#######")
    log.debug("####### Plugins to be loaded:")
    // Iterating values (not a typed String pattern) avoids a MatchError on null canonical names.
    result.values.foreach(canonicalName => log.debug(s"$canonicalName"))

    result
  }

  /** Debug-logs the class loaders involved in plugin loading and their URLs; best effort only. */
  private def logClassLoaders(): Unit =
    try {
      log.debug("#######")
      log.debug("####### SPARK MUTABLE_URL_CLASS_LOADER:")
      log.debug(getClass.getClassLoader.toString)
      printClassPath(getClass.getClassLoader)
      log.debug("#######")
      log.debug("####### APP_CLASS_LOADER / SYSTEM CLASSLOADER:")
      log.debug(ClassLoader.getSystemClassLoader().toString)
      printClassPath(ClassLoader.getSystemClassLoader())
      log.debug("#######")
      log.debug("####### EXTRA_CLASS_LOADER:")
      log.debug(getClass.getClassLoader.getParent.getParent.toString)
      printClassPath(getClass.getClassLoader.getParent.getParent)
    } catch {
      // Diagnostics must never prevent plugin discovery (e.g. a loader with no grandparent).
      case e: Exception => log.debug("Could not log class loader details", e)
    }
}
Example 13
Source file: Module.scala, from the Cortex project (GNU Affero General Public License v3.0); 5 votes
package org.thp.cortex

import java.lang.reflect.Modifier

import com.google.inject.AbstractModule
import net.codingwell.scalaguice.{ScalaModule, ScalaMultibinder}
import play.api.libs.concurrent.AkkaGuiceSupport
import play.api.{Configuration, Environment, Logger, Mode}
import scala.collection.JavaConverters._

import com.google.inject.name.Names
import org.reflections.Reflections
import org.reflections.scanners.SubTypesScanner
import org.reflections.util.ConfigurationBuilder
import org.thp.cortex.models.{AuditedModel, Migration}
import org.thp.cortex.services._

import org.elastic4play.models.BaseModelDef
import org.elastic4play.services.auth.MultiAuthSrv
import org.elastic4play.services.{UserSrv ⇒ EUserSrv, AuthSrv, MigrationOperations}
import org.thp.cortex.controllers.{AssetCtrl, AssetCtrlDev, AssetCtrlProd}
import services.mappers.{MultiUserMapperSrv, UserMapper}

/**
 * Guice module wiring the Cortex application.
 *
 * Model definitions, authentication services and SSO user mappers are discovered by scanning the
 * `org.elastic4play` and `org.thp.cortex` packages and added to Guice multibinders; every other
 * binding is declared explicitly at the end of `configure`.
 *
 * @param environment   Play environment; its mode selects the asset controller and its class
 *                      loader is added to the classpath scan
 * @param configuration Play configuration (not read by `configure`)
 */
class Module(environment: Environment, configuration: Configuration) extends AbstractModule with ScalaModule with AkkaGuiceSupport {

  private lazy val logger = Logger("module")

  override def configure(): Unit = {
    val modelBindings        = ScalaMultibinder.newSetBinder[BaseModelDef](binder)
    val auditedModelBindings = ScalaMultibinder.newSetBinder[AuditedModel](binder)

    // Subtype index over both the elastic4play library and Cortex itself. Super types are not
    // expanded, and SubTypesScanner(false) does not exclude java.lang.Object.
    val scanner = new Reflections(
      new ConfigurationBuilder()
        .forPackages("org.elastic4play")
        .addClassLoader(getClass.getClassLoader)
        .addClassLoader(environment.getClass.getClassLoader)
        .forPackages("org.thp.cortex")
        .setExpandSuperTypes(false)
        .setScanners(new SubTypesScanner(false))
    )

    // Abstract classes cannot be bound as implementations.
    def isConcrete(cls: Class[_]): Boolean = !Modifier.isAbstract(cls.getModifiers)
    // Auth services and user mappers must additionally be top-level (non-member) classes.
    def isPluggable(cls: Class[_]): Boolean = isConcrete(cls) && !cls.isMemberClass

    scanner
      .getSubTypesOf(classOf[BaseModelDef])
      .asScala
      .filter(cls ⇒ isConcrete(cls))
      .foreach { modelClass ⇒
        logger.info(s"Loading model $modelClass")
        modelBindings.addBinding.to(modelClass)
        // Audited models are also registered in their own multibinder.
        if (classOf[AuditedModel].isAssignableFrom(modelClass))
          auditedModelBindings.addBinding.to(modelClass.asInstanceOf[Class[AuditedModel]])
      }

    // MultiAuthSrv and CortexAuthSrv stay out of the multibinder; CortexAuthSrv is bound
    // directly to AuthSrv further down.
    val authBindings = ScalaMultibinder.newSetBinder[AuthSrv](binder)
    scanner
      .getSubTypesOf(classOf[AuthSrv])
      .asScala
      .filter(cls ⇒ isPluggable(cls))
      .filterNot(cls ⇒ cls == classOf[MultiAuthSrv] || cls == classOf[CortexAuthSrv])
      .foreach { authSrvClass ⇒
        logger.info(s"Loading authentication module $authSrvClass")
        authBindings.addBinding.to(authSrvClass)
      }

    // MultiUserMapperSrv is excluded here; it is bound directly to UserMapper further down.
    val ssoMapperBindings = ScalaMultibinder.newSetBinder[UserMapper](binder)
    scanner
      .getSubTypesOf(classOf[UserMapper])
      .asScala
      .filter(cls ⇒ isPluggable(cls))
      .filterNot(cls ⇒ cls == classOf[MultiUserMapperSrv])
      .foreach(mapperClass ⇒ ssoMapperBindings.addBinding.to(mapperClass))

    environment.mode match {
      case Mode.Prod ⇒ bind[AssetCtrl].to[AssetCtrlProd]
      case _         ⇒ bind[AssetCtrl].to[AssetCtrlDev]
    }

    bind[EUserSrv].to[UserSrv]
    bind[Int].annotatedWith(Names.named("databaseVersion")).toInstance(models.modelVersion)
    bind[UserMapper].to[MultiUserMapperSrv]

    bind[AuthSrv].to[CortexAuthSrv]
    bind[MigrationOperations].to[Migration]
    bindActor[AuditActor]("audit")
  }
}
Example 14
Source File: LazyApplicationSuite.scala — from the darwin project, licensed under the Apache License 2.0 (5 votes)
package it.agilelab.darwin.app.mock

import java.lang.reflect.Modifier
import java.nio.ByteOrder

import com.typesafe.config.{Config, ConfigFactory}
import it.agilelab.darwin.annotations.AvroSerde
import it.agilelab.darwin.app.mock.classes.{MyClass, MyNestedClass, NewClass, OneField}
import it.agilelab.darwin.common.{Connector, ConnectorFactory}
import it.agilelab.darwin.manager.{AvroSchemaManager, LazyAvroSchemaManager}
import org.apache.avro.{Schema, SchemaNormalization}
import org.apache.avro.reflect.ReflectData
import org.reflections.Reflections

import it.agilelab.darwin.common.compat._
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class BigEndianLazyApplicationSuite extends LazyApplicationSuite(ByteOrder.BIG_ENDIAN)

class LittleEndianLazyApplicationSuite extends LazyApplicationSuite(ByteOrder.LITTLE_ENDIAN)

/**
 * Integration tests for [[LazyAvroSchemaManager]] against the connector built from the default
 * Typesafe configuration.
 *
 * All cases share one `connector`/`manager` pair and are order dependent: later cases rely on
 * schemas registered by earlier ones (or already present in the connector).
 *
 * @param endianness byte order handed to the schema manager
 */
abstract class LazyApplicationSuite(endianness: ByteOrder) extends AnyFlatSpec with Matchers {

  val config: Config = ConfigFactory.load()
  val connector: Connector = ConnectorFactory.connector(config)
  val manager: AvroSchemaManager = new LazyAvroSchemaManager(connector, endianness)

  "LazyAvroSchemaManager" should "not fail after the initialization" in {
    val schemas: Seq[Schema] = Seq(SchemaReader.readFromResources("MyNestedClass.avsc"))
    manager.registerAll(schemas).size shouldBe 1
  }

  it should "load all existing schemas and register a new one" in {
    val schemas: Seq[Schema] = Seq(SchemaReader.readFromResources("MyNestedClass.avsc"))
    // Result intentionally discarded; presumably this forces the manager to load existing
    // schemas before registering (TODO confirm).
    manager.getSchema(0L)

    manager.registerAll(schemas)

    val id = manager.getId(schemas.head)
    manager.getSchema(id) shouldBe Some(schemas.head)
  }

  it should "get all previously registered schemas" in {
    val schema: Schema = SchemaReader.readFromResources("MyNestedClass.avsc")
    // Ids 0 and 1 are presumably pre-seeded in the test connector's data (TODO confirm).
    val schema0 = manager.getSchema(0L).getOrElse(fail("no schema found for id 0"))
    val schema1 = manager.getSchema(1L).getOrElse(fail("no schema found for id 1"))
    schema0 should not be (schema1)
    schema should not be (schema0)
    schema should not be (schema1)
  }

  it should "generate all schemas for all the annotated classes with @AvroSerde" in {
    val reflections = new Reflections("it.agilelab.darwin.app.mock.classes")

    val oneFieldSchema = ReflectData.get().getSchema(classOf[OneField]).toString
    val myNestedSchema = ReflectData.get().getSchema(classOf[MyNestedClass]).toString
    val myClassSchema = ReflectData.get().getSchema(classOf[MyClass]).toString

    val annotationClass: Class[AvroSerde] = classOf[AvroSerde]
    // Only concrete classes can have a reflected record schema.
    val classes = reflections.getTypesAnnotatedWith(annotationClass).toScala.toSeq
      .filter(c => !c.isInterface && !Modifier.isAbstract(c.getModifiers))
    val schemas = classes.map(c => ReflectData.get().getSchema(Class.forName(c.getName)).toString)
    Seq(oneFieldSchema, myClassSchema, myNestedSchema) should contain theSameElementsAs schemas
  }

  it should "reload all schemas from the connector" in {
    val newSchema = ReflectData.get().getSchema(classOf[NewClass])
    val newId = SchemaNormalization.parsingFingerprint64(newSchema)
    manager.getSchema(newId) shouldBe None

    // The lazy manager must pick up a schema inserted directly into the connector.
    connector.insert(Seq(newId -> newSchema))
    manager.getSchema(newId) shouldBe Some(newSchema)
  }
}
Example 15
Source File: CachedLazyApplicationSuite.scala — from the darwin project, licensed under the Apache License 2.0 (5 votes)
package it.agilelab.darwin.app.mock

import java.lang.reflect.Modifier
import java.nio.ByteOrder

import com.typesafe.config.{Config, ConfigFactory}
import it.agilelab.darwin.annotations.AvroSerde
import it.agilelab.darwin.app.mock.classes.{MyClass, MyNestedClass, NewClass, OneField}
import it.agilelab.darwin.common.{Connector, ConnectorFactory}
import it.agilelab.darwin.manager.{AvroSchemaManager, CachedLazyAvroSchemaManager}
import org.apache.avro.{Schema, SchemaNormalization}
import org.apache.avro.reflect.ReflectData
import org.reflections.Reflections

import it.agilelab.darwin.common.compat._
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class BigEndianCachedLazyApplicationSuite extends CachedLazyApplicationSuite(ByteOrder.BIG_ENDIAN)

class LittleEndianCachedLazyApplicationSuite extends CachedLazyApplicationSuite(ByteOrder.LITTLE_ENDIAN)

/**
 * Integration tests for [[CachedLazyAvroSchemaManager]] against the connector built from the
 * default Typesafe configuration.
 *
 * All cases share one `connector`/`manager` pair and are order dependent: later cases rely on
 * schemas registered by earlier ones (or already present in the connector).
 *
 * @param endianness byte order handed to the schema manager
 */
abstract class CachedLazyApplicationSuite(val endianness: ByteOrder) extends AnyFlatSpec with Matchers {

  val config: Config = ConfigFactory.load()
  val connector: Connector = ConnectorFactory.connector(config)
  val manager: AvroSchemaManager = new CachedLazyAvroSchemaManager(connector, endianness)

  "CachedLazyAvroSchemaManager" should "not fail after the initialization" in {
    val schemas: Seq[Schema] = Seq(SchemaReader.readFromResources("MyNestedClass.avsc"))
    manager.registerAll(schemas).size shouldBe 1
  }

  it should "load all existing schemas and register a new one" in {
    val schemas: Seq[Schema] = Seq(SchemaReader.readFromResources("MyNestedClass.avsc"))
    // Result intentionally discarded; presumably this forces the manager to load existing
    // schemas before registering (TODO confirm).
    manager.getSchema(0L)

    manager.registerAll(schemas)

    val id = manager.getId(schemas.head)
    manager.getSchema(id) shouldBe Some(schemas.head)
  }

  it should "get all previously registered schemas" in {
    val schema: Schema = SchemaReader.readFromResources("MyNestedClass.avsc")
    // Ids 0 and 1 are presumably pre-seeded in the test connector's data (TODO confirm).
    val schema0 = manager.getSchema(0L).getOrElse(fail("no schema found for id 0"))
    val schema1 = manager.getSchema(1L).getOrElse(fail("no schema found for id 1"))
    schema0 should not be (schema1)
    schema should not be (schema0)
    schema should not be (schema1)
  }

  it should "generate all schemas for all the annotated classes with @AvroSerde" in {
    val reflections = new Reflections("it.agilelab.darwin.app.mock.classes")

    val oneFieldSchema = ReflectData.get().getSchema(classOf[OneField]).toString
    val myNestedSchema = ReflectData.get().getSchema(classOf[MyNestedClass]).toString
    val myClassSchema = ReflectData.get().getSchema(classOf[MyClass]).toString

    val annotationClass: Class[AvroSerde] = classOf[AvroSerde]
    // Only concrete classes can have a reflected record schema.
    val classes = reflections.getTypesAnnotatedWith(annotationClass).toScala.toSeq
      .filter(c => !c.isInterface && !Modifier.isAbstract(c.getModifiers))
    val schemas = classes.map(c => ReflectData.get().getSchema(Class.forName(c.getName)).toString)
    Seq(oneFieldSchema, myClassSchema, myNestedSchema) should contain theSameElementsAs schemas
  }

  it should "reload all schemas from the connector" in {
    val newSchema = ReflectData.get().getSchema(classOf[NewClass])
    val newId = SchemaNormalization.parsingFingerprint64(newSchema)
    manager.getSchema(newId) shouldBe None

    // The cached-lazy manager must pick up a schema inserted directly into the connector.
    connector.insert(Seq(newId -> newSchema))
    manager.getSchema(newId) shouldBe Some(newSchema)
  }
}
Example 16
Source File: CachedEagerApplicationSuite.scala — from the darwin project, licensed under the Apache License 2.0 (5 votes)
package it.agilelab.darwin.app.mock

import java.lang.reflect.Modifier
import java.nio.ByteOrder

import com.typesafe.config.{Config, ConfigFactory}
import it.agilelab.darwin.annotations.AvroSerde
import it.agilelab.darwin.app.mock.classes.{MyClass, MyNestedClass, NewClass, OneField}
import it.agilelab.darwin.common.{Connector, ConnectorFactory}
import it.agilelab.darwin.manager.{AvroSchemaManager, CachedEagerAvroSchemaManager}
import org.apache.avro.{Schema, SchemaNormalization}
import org.apache.avro.reflect.ReflectData
import org.reflections.Reflections

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import it.agilelab.darwin.common.compat._

class BigEndianCachedEagerApplicationSuite extends CachedEagerApplicationSuite(ByteOrder.BIG_ENDIAN)

class LittleEndianCachedEagerApplicationSuite extends CachedEagerApplicationSuite(ByteOrder.LITTLE_ENDIAN)

/**
 * Integration tests for [[CachedEagerAvroSchemaManager]] against the connector built from the
 * default Typesafe configuration.
 *
 * All cases share one `connector`/`manager` pair and are order dependent: later cases rely on
 * schemas registered by earlier ones (or already present in the connector).
 *
 * @param endianness byte order handed to the schema manager
 */
abstract class CachedEagerApplicationSuite(val endianness: ByteOrder) extends AnyFlatSpec with Matchers {

  val config: Config = ConfigFactory.load()
  val connector: Connector = ConnectorFactory.connector(config)
  val manager: AvroSchemaManager = new CachedEagerAvroSchemaManager(connector, endianness)

  "CachedEagerAvroSchemaManager" should "not fail after the initialization" in {
    val schemas: Seq[Schema] = Seq(SchemaReader.readFromResources("MyNestedClass.avsc"))
    manager.registerAll(schemas).size shouldBe 1
  }

  it should "load all existing schemas and register a new one" in {
    val schemas: Seq[Schema] = Seq(SchemaReader.readFromResources("MyNestedClass.avsc"))
    // Result intentionally discarded; presumably this exercises lookup before registering
    // (TODO confirm).
    manager.getSchema(0L)

    manager.registerAll(schemas)

    val id = manager.getId(schemas.head)
    manager.getSchema(id) shouldBe Some(schemas.head)
  }

  it should "get all previously registered schemas" in {
    val schema: Schema = SchemaReader.readFromResources("MyNestedClass.avsc")
    // Ids 0 and 1 are presumably pre-seeded in the test connector's data (TODO confirm).
    val schema0 = manager.getSchema(0L).getOrElse(fail("no schema found for id 0"))
    val schema1 = manager.getSchema(1L).getOrElse(fail("no schema found for id 1"))
    schema0 should not be (schema1)
    schema should not be (schema0)
    schema should not be (schema1)
  }

  it should "generate all schemas for all the annotated classes with @AvroSerde" in {
    val reflections = new Reflections("it.agilelab.darwin.app.mock.classes")

    val oneFieldSchema = ReflectData.get().getSchema(classOf[OneField]).toString
    val myNestedSchema = ReflectData.get().getSchema(classOf[MyNestedClass]).toString
    val myClassSchema = ReflectData.get().getSchema(classOf[MyClass]).toString

    val annotationClass: Class[AvroSerde] = classOf[AvroSerde]
    // Only concrete classes can have a reflected record schema.
    val classes = reflections.getTypesAnnotatedWith(annotationClass).toScala.toSeq
      .filter(c => !c.isInterface && !Modifier.isAbstract(c.getModifiers))
    val schemas = classes.map(c => ReflectData.get().getSchema(Class.forName(c.getName)).toString)
    Seq(oneFieldSchema, myClassSchema, myNestedSchema) should contain theSameElementsAs schemas
  }

  it should "reload all schemas from the connector" in {
    val newSchema = ReflectData.get().getSchema(classOf[NewClass])
    val newId = SchemaNormalization.parsingFingerprint64(newSchema)
    manager.getSchema(newId) shouldBe None

    // The eager cache must not see a schema inserted directly into the connector...
    connector.insert(Seq(newId -> newSchema))
    manager.getSchema(newId) shouldBe None

    // ...until it is explicitly reloaded.
    manager.reload()
    manager.getSchema(newId) shouldBe Some(newSchema)
  }

}
Example 17
Source File: DWSHttpMessageFactory.scala — from the Linkis project, licensed under the Apache License 2.0 (5 votes)
/**
 * Maps DWS HTTP method identifiers to the result classes registered for them.
 *
 * The index is built once, when the object is initialised, from every class under
 * `com.webank.wedatasphere` that carries [[DWSHttpMessageResult]] and is assignable to `Result`.
 * The annotation's `value()` is used as the key.
 */
object DWSHttpMessageFactory {

  private val reflections = new Reflections("com.webank.wedatasphere", classOf[DWSHttpMessageResult].getClassLoader)

  // Annotation value -> result info. If two classes share a value, `toMap` keeps the pair seen
  // last in iteration order and the other is dropped silently.
  private val methodToHttpMessageClasses = {
    val resultClasses = reflections
      .getTypesAnnotatedWith(classOf[DWSHttpMessageResult])
      .filter(clazz => ClassUtils.isAssignable(clazz, classOf[Result]))
    resultClasses.map { clazz =>
      val method = clazz.getAnnotation(classOf[DWSHttpMessageResult]).value()
      method -> DWSHttpMessageResultInfo(method, clazz)
    }.toMap
  }

  // The same keys are also tried as regular expressions by the fallback lookup.
  private val methodRegex = methodToHttpMessageClasses.keys.toArray

  /**
   * Looks up the result info registered for `method`.
   *
   * An exact key match wins. Otherwise the first key that, used as a regex, matches the whole of
   * `method` (via `String.matches`) is returned.
   *
   * NOTE(review): each fallback call recompiles every key as a pattern. When several patterns
   * match, the winner depends on the Map's key order.
   *
   * @param method the HTTP method identifier to resolve
   * @return the matching info, or `None` when neither an exact key nor a pattern matches
   */
  def getDWSHttpMessageResult(method: String): Option[DWSHttpMessageResultInfo] =
    methodToHttpMessageClasses.get(method) match {
      case exact @ Some(_) => exact
      case None =>
        methodRegex
          .find(pattern => method.matches(pattern))
          .map(pattern => methodToHttpMessageClasses(pattern))
    }

}
/**
 * Result class registered for a DWS HTTP method.
 *
 * @param method the annotation value; the factory uses it both as an exact key and as a regex
 * @param clazz  the annotated result class
 */
case class DWSHttpMessageResultInfo(method: String, clazz: Class[_])
Example 18
Source File: SerializerSpecHelper.scala — from the BigDL project, licensed under the Apache License 2.0 (5 votes)
package com.intel.analytics.bigdl.utils.serializer

import java.io.{File}
import java.lang.reflect.Modifier

import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.ops.{Exp => ExpOps, Pow => PowOps, Select => SelectOps, Sum => SumOps, Tile => TileOps}
import com.intel.analytics.bigdl.nn.tf.{DecodeGif => DecodeGifOps, DecodeJpeg => DecodeJpegOps, DecodePng => DecodePngOps, DecodeRaw => DecodeRawOps}
import com.intel.analytics.bigdl.utils.RandomGenerator.RNG
import com.intel.analytics.bigdl.utils.tf.loaders.{Pack => _}
import com.intel.analytics.bigdl.utils.{Shape => KShape}
import org.reflections.Reflections
import org.reflections.scanners.SubTypesScanner
import org.reflections.util.{ClasspathHelper, ConfigurationBuilder, FilterBuilder}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

import scala.collection.JavaConverters._
import scala.collection.mutable


/**
 * Base spec for BigDL module serialization round-trip tests.
 *
 * `beforeAll` scans the classpath location of `getPackage()` and collects the names of every
 * concrete [[AbstractModule]] subclass that is not excluded. The `runSerializationTest*` helpers
 * save and reload a module, check that the forward output is unchanged, and record the class as
 * tested. `afterAll` fails if any expected class was never tested.
 */
abstract class SerializerSpecHelper extends FlatSpec with Matchers with BeforeAndAfterAll{

  // Suffix for the temporary serialization file.
  val postFix = "bigdl"
  // Fully-qualified class names excluded from the coverage check.
  val excludedClass = new mutable.HashSet[String]()
  // Package prefixes excluded from the classpath scan.
  val excludedPackage = new mutable.HashSet[String]()

  private val expected = new mutable.HashSet[String]()
  val tested = new mutable.HashSet[String]()

  /** Package whose classpath location (via `ClasspathHelper.forPackage`) is scanned. */
  protected def getPackage(): String = ""

  /** Hook for subclasses to populate `excludedClass` before the scan. */
  protected def addExcludedClass(): Unit = {}

  /** Hook for subclasses to populate `excludedPackage` before the scan. */
  protected def addExcludedPackage(): Unit = {}

  /** Class names that must be covered by a serialization test. */
  protected def getExpected(): mutable.Set[String] = expected

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    addExcludedClass()
    addExcludedPackage()
    val filterBuilder = new FilterBuilder()
    excludedPackage.foreach(filterBuilder.excludePackage(_))
    val reflections = new Reflections(new ConfigurationBuilder()
      .filterInputsBy(filterBuilder)
      .setUrls(ClasspathHelper.forPackage(getPackage()))
      .setScanners(new SubTypesScanner()))
    reflections.getSubTypesOf(classOf[AbstractModule[_, _, _]])
      .asScala
      .filter(sub => !Modifier.isAbstract(sub.getModifiers))
      .filter(sub => !excludedClass.contains(sub.getName))
      .foreach(sub => expected.add(sub.getName))
  }

  /**
   * Round-trips `module` through serialization and records `cls` (or, when `cls` is null, the
   * module's own class) as tested.
   */
  protected def runSerializationTest(module : AbstractModule[_, _, Float],
                                   input : Activity, cls: Class[_] = null) : Unit = {
    runSerializationTestWithMultiClass(module, input,
      if (cls == null) Array(module.getClass) else Array(cls))
  }

  /**
   * Saves `module` (in evaluation mode) to a temporary file and loads it back, then asserts that
   * the loaded module's forward output on `input` equals the original's. On success, every entry
   * of `classes` that is in the expected set is recorded as tested.
   *
   * The temporary file is removed even if saving, loading or forwarding throws.
   */
  protected def runSerializationTestWithMultiClass(module : AbstractModule[_, _, Float],
      input : Activity, classes: Array[Class[_]]) : Unit = {
    val name = module.getName
    val serFile = File.createTempFile(name, postFix)
    val (originForward, afterLoadForward) = try {
      val origin = module.evaluate().forward(input)
      ModulePersister.saveToFile[Float](serFile.getAbsolutePath, null, module.evaluate(), true)
      // Fixed seed before loading so any random initialisation in the loader is reproducible.
      RNG.setSeed(1000)
      val loadedModule = ModuleLoader.loadFromFile[Float](serFile.getAbsolutePath)
      (origin, loadedModule.forward(input))
    } finally {
      if (serFile.exists) {
        serFile.delete()
      }
    }

    afterLoadForward should be (originForward)
    classes.foreach(cls => {
      if (getExpected().contains(cls.getName)) {
        tested.add(cls.getName)
      }
    })
  }

  override protected def afterAll(): Unit = {
    try {
      // Count the expected classes that were actually missed. `tested` is public, so it may hold
      // names outside the expected set, and subtracting sizes would then undercount.
      val untested = getExpected().filter(exp => !tested.contains(exp))
      println(s"total ${getExpected().size}, remaining ${untested.size}")
      tested.filter(t => !getExpected().contains(t)).foreach(t => {
        println(s"$t do not need to be tested")
      })
      getExpected().foreach(exp => {
        require(tested.contains(exp), s" $exp not included in the test!")
      })
    } finally {
      super.afterAll()
    }
  }
}