Python pyspark.SparkContext.getOrCreate() Examples

The following are 25 code examples of pyspark.SparkContext.getOrCreate(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available methods of the class pyspark.SparkContext, or try the search function.
Example #1
Source File: spark.py    From qb with MIT License 6 votes vote down vote up
def create_spark_context(app_name="Quiz Bowl", configs=None) -> SparkContext:
    """Build this project's SparkConf and return ``SparkContext.getOrCreate`` for it.

    :param app_name: Spark application name.
    :param configs: optional extra Spark settings, either an iterable of ``(key, value)`` pairs
        or a mapping. Core-count settings larger than ``QB_MAX_CORES`` are reduced to it.
    :return: the SparkContext returned by ``SparkContext.getOrCreate(spark_conf)``. Per the
        pyspark docs, an already-running context is returned as-is, so this configuration is
        only applied when a new context is created.
    """
    spark_conf = SparkConf() \
        .set('spark.rpc.message.maxSize', 300) \
        .setAppName(app_name)
    if QB_SPARK_MASTER != "":
        log.info("Spark master is %s", QB_SPARK_MASTER)
        spark_conf = spark_conf.setMaster(QB_SPARK_MASTER)
    if configs is not None:
        # 'spark.cores.max' is Spark's actual property name; 'spark.max.cores' is still
        # recognized for backward compatibility with existing callers.
        core_keys = ('spark.executor.cores', 'spark.cores.max', 'spark.max.cores')
        # Accept a mapping as well as the original iterable of (key, value) pairs.
        items = configs.items() if hasattr(configs, 'items') else configs
        for key, value in items:
            # SparkConf values are often strings; compare numerically so "8" works too.
            if key in core_keys and int(value) > QB_MAX_CORES:
                log.info('Requested {r_cores} cores when the machine only has {n_cores} cores, reducing number of '
                         'cores to {n_cores}'.format(r_cores=value, n_cores=QB_MAX_CORES))
                value = QB_MAX_CORES
            spark_conf = spark_conf.set(key, value)
    return SparkContext.getOrCreate(spark_conf)
Example #2
Source File: spark.py    From mlflow with Apache License 2.0 6 votes vote down vote up
def _load_pyfunc(path):
    """
    Load the PyFunc implementation for the ``spark`` flavor. Called by ``pyfunc.load_pyfunc``.

    :param path: Local filesystem path to the MLflow Model with the ``spark`` flavor.
    :return: a ``_PyFuncModelWrapper`` around the loaded model and a SparkSession.
    """
    # Reuse an already-instantiated SparkSession when one exists. Calling getOrCreate() on a
    # builder could change settings of that active session; in particular, forcing the master
    # to local[1] can break distributed clusters. Reading the private _instantiatedSession
    # attribute is not ideal, but there is no good workaround at the moment.
    import pyspark

    session = pyspark.sql.SparkSession._instantiatedSession
    if session is None:
        builder = pyspark.sql.SparkSession.builder
        builder = builder.config("spark.python.worker.reuse", True)
        session = builder.master("local[1]").getOrCreate()
    model = _load_model(model_uri=path)
    return _PyFuncModelWrapper(session, model)
Example #3
Source File: session.py    From FATE with Apache License 2.0 6 votes vote down vote up
def parallelize(self,
                    data: Iterable,
                    name,
                    namespace,
                    partition,
                    include_key,
                    persistent,
                    chunk_size,
                    in_place_computing,
                    create_if_missing,
                    error_if_exist):
        """Distribute *data* into a Spark RDD and wrap it in an ``RDDTable``.

        When *include_key* is false, each element is paired with its index via ``enumerate``.
        A *namespace* of ``None`` falls back to this session's id. The parameters
        ``persistent``, ``chunk_size``, ``in_place_computing``, ``create_if_missing`` and
        ``error_if_exist`` are accepted but not used by this implementation.
        """
        pairs = enumerate(data) if not include_key else data
        from pyspark import SparkContext
        spark_ctx = SparkContext.getOrCreate()
        rdd = util.materialize(spark_ctx.parallelize(pairs, partition))
        table_namespace = self._session_id if namespace is None else namespace
        return RDDTable.from_rdd(rdd=rdd, job_id=self._session_id,
                                 namespace=table_namespace, name=name)
Example #4
Source File: session.py    From FATE with Apache License 2.0 6 votes vote down vote up
def parallelize(self,
                    data: Iterable,
                    name,
                    namespace,
                    partition,
                    include_key,
                    persistent,
                    chunk_size,
                    in_place_computing,
                    create_if_missing,
                    error_if_exist):
        """Distribute *data* into a Spark RDD and wrap it in an ``RDDTable``.

        When *include_key* is false, each element is paired with its index via ``enumerate``.
        A *namespace* of ``None`` falls back to this session's id. The parameters
        ``persistent``, ``chunk_size``, ``in_place_computing``, ``create_if_missing`` and
        ``error_if_exist`` are accepted but not used by this implementation.
        """
        pairs = enumerate(data) if not include_key else data
        from pyspark import SparkContext
        spark_ctx = SparkContext.getOrCreate()
        rdd = util.materialize(spark_ctx.parallelize(pairs, partition))
        table_namespace = self._session_id if namespace is None else namespace
        return RDDTable.from_rdd(rdd=rdd, job_id=self._session_id,
                                 namespace=table_namespace, name=name)
Example #5
Source File: test_spark.py    From snorkel with Apache License 2.0 6 votes vote down vote up
def test_lf_applier_spark_preprocessor_memoized(self) -> None:
        """Applying LFs that use a memoized preprocessor over a Spark RDD gives the expected labels."""
        sql = SQLContext(SparkContext.getOrCreate())

        # The LF/preprocessor names come from these function names, so keep them as-is.
        @preprocessor(memoize=True)
        def square_memoize(x: DataPoint) -> DataPoint:
            return Row(num=x.num, num_squared=x.num ** 2)

        @labeling_function(pre=[square_memoize])
        def fp_memoized(x: DataPoint) -> int:
            return 0 if x.num_squared > 42 else -1

        frame = pd.DataFrame(dict(num=DATA))
        rdd = sql.createDataFrame(frame).rdd
        label_matrix = SparkLFApplier([f, fp_memoized]).apply(rdd)
        np.testing.assert_equal(label_matrix, L_PREPROCESS_EXPECTED)
Example #6
Source File: test_spark.py    From snorkel with Apache License 2.0 5 votes vote down vote up
def test_lf_applier_spark_fault(self) -> None:
        """A failing LF raises by default; with fault_tolerant=True the result is L_EXPECTED_BAD."""
        sql = SQLContext(SparkContext.getOrCreate())
        rdd = sql.createDataFrame(pd.DataFrame(dict(num=DATA))).rdd
        applier = SparkLFApplier([f, f_bad])
        with self.assertRaises(Exception):
            applier.apply(rdd)
        np.testing.assert_equal(applier.apply(rdd, fault_tolerant=True), L_EXPECTED_BAD)
Example #7
Source File: spark.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def _conf(cls):
        """Return the Hadoop configuration of the current (or newly created) SparkContext."""
        from pyspark import SparkContext

        java_spark_ctx = SparkContext.getOrCreate()._jsc
        return java_spark_ctx.hadoopConfiguration()
Example #8
Source File: _spark_autologging.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def _get_repl_id():
    """
    Return an identifier for this Python REPL, used to tag a PythonSubscriber instance.

    In multitenant, REPL-aware environments where several Python processes share one Spark
    JVM (e.g. Databricks), the ID is read from the ``spark.databricks.replId`` local property,
    so the PythonSubscriber of this process only receives datasource-read events that this
    process triggered. Otherwise a fresh ID is built from the main script name (or
    ``<console>``) and a random UUID.
    """
    databricks_repl_id = SparkContext.getOrCreate().getLocalProperty("spark.databricks.replId")
    if not databricks_repl_id:
        script = sys.argv[0] if sys.argv else "<console>"
        return "PythonSubscriber[{filename}][{id}]".format(filename=script, id=uuid.uuid4().hex)
    return databricks_repl_id
Example #9
Source File: spark_tree_plotting.py    From spark-tree-plotting with MIT License 5 votes vote down vote up
def generate_tree_json(DecisionTreeClassificationModel, withNodeIDs=False):
    """
    Build a recursive JSON string describing the structure of a fitted Spark
    DecisionTreeClassificationModel.

    Each node carries its attributes (split rule, impurity, ...) and its children, which are
    nodes of the same shape, down to the leaves. The JSON is produced on the JVM side by
    ``com.vfive.spark.ml.SparkMLTree.toJsonPlotFormat`` and is meant to be fed to any plotting
    library that accepts JSON.

    Arguments:
    DecisionTreeClassificationModel -- a pyspark.ml.classification.DecisionTreeClassificationModel
                                       instance

    withNodeIDs -- when true, the JSON is re-parsed (preserving key order), passed through
                   add_node_ids to attach node IDs, and re-serialized with indent=2

    Returns:
    the tree as a JSON string
    """
    jvm = SparkContext.getOrCreate()._jvm
    java_tree = jvm.com.vfive.spark.ml.SparkMLTree(DecisionTreeClassificationModel._java_obj)
    tree_json = java_tree.toJsonPlotFormat()

    if not withNodeIDs:
        return tree_json

    ordered_tree = loads(tree_json, object_pairs_hook=OrderedDict)
    return dumps(add_node_ids(ordered_tree), indent=2)
Example #10
Source File: common.py    From LearningApacheSpark with MIT License 5 votes vote down vote up
def __init__(self, java_model):
        """Wrap a JVM-side model, keeping a handle to the current SparkContext.

        :param java_model: the py4j handle of the underlying JVM model.
        """
        spark_ctx = SparkContext.getOrCreate()
        self._sc = spark_ctx
        self._java_model = java_model
Example #11
Source File: common.py    From LearningApacheSpark with MIT License 5 votes vote down vote up
def callMLlibFunc(name, *args):
    """Call the PythonMLLibAPI method *name* on the JVM with *args* (via callJavaFunc)."""
    spark_ctx = SparkContext.getOrCreate()
    mllib_api = spark_ctx._jvm.PythonMLLibAPI()
    java_method = getattr(mllib_api, name)
    return callJavaFunc(spark_ctx, java_method, *args)
Example #12
Source File: common.py    From LearningApacheSpark with MIT License 5 votes vote down vote up
def _java2py(sc, r, encoding="bytes"):
    """Convert a value returned from the JVM through py4j into a Python object.

    Conversion rules, applied in order:

    - A ``JavaObject`` whose simple class name ends in "RDD" (other than ``JavaRDD``) is first
      converted with ``toJavaRDD()``; a ``JavaRDD`` is then turned into a Python ``RDD`` via
      ``SerDe.javaToPython``.
    - A JVM ``Dataset`` is wrapped as a ``DataFrame``.
    - Instances of ``_picklable_classes`` are pickled on the JVM side with ``SerDe.dumps``;
      Java arrays/lists are too, unless ``SerDe.dumps`` raises ``Py4JJavaError``, in which
      case they are returned unchanged.
    - Any ``bytes``/``bytearray`` value (including the pickles produced above, and raw Python
      bytes passed in directly) is unpickled with ``PickleSerializer``.
    - Everything else is returned unchanged.

    :param sc: SparkContext whose ``_jvm`` gateway is used for the SerDe calls.
    :param r: the value to convert.
    :param encoding: encoding passed to ``PickleSerializer().loads``.
    :return: the converted Python value.
    """
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = 'JavaRDD'

        if clsName == 'JavaRDD':
            jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r)
            return RDD(jrdd, sc)

        if clsName == 'Dataset':
            return DataFrame(r, SQLContext.getOrCreate(sc))

        if clsName in _picklable_classes:
            r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
            except Py4JJavaError:
                pass  # not picklable; fall through and return the Java object as-is

    # Unpickle serialized payloads (from SerDe.dumps above or passed in directly).
    if isinstance(r, (bytearray, bytes)):
        r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r 
Example #13
Source File: test_spark.py    From snorkel with Apache License 2.0 5 votes vote down vote up
def test_lf_applier_spark_preprocessor(self) -> None:
        """LFs with a preprocessor applied over a Spark RDD produce L_PREPROCESS_EXPECTED."""
        sql_ctx = SQLContext(SparkContext.getOrCreate())
        rdd = sql_ctx.createDataFrame(pd.DataFrame(dict(num=DATA))).rdd
        label_matrix = SparkLFApplier([f, fp]).apply(rdd)
        np.testing.assert_equal(label_matrix, L_PREPROCESS_EXPECTED)
Example #14
Source File: wikidata.py    From qb with MIT License 5 votes vote down vote up
def parse_raw_wikidata(output):
    """Parse the raw Wikidata JSON dump with Spark and pickle the extracted maps to *output*.

    The dump is read from S3. ``parse_line`` expects one JSON entity per line, with the
    enclosing ``[``/``]`` lines of a top-level JSON array and trailing comma separators. The
    pickled dict has the keys ``parsed_item_map``, ``item_page_map`` and ``property_map``.
    The SparkContext is stopped when the function finishes, whether it succeeds or fails.

    :param output: path of the pickle file to write.
    """
    spark_conf = SparkConf().setAppName('QB Wikidata').setMaster(QB_SPARK_MASTER)
    sc = SparkContext.getOrCreate(spark_conf)  # type: SparkContext
    try:
        wikidata = sc.textFile('s3a://entilzha-us-west-2/wikidata/wikidata-20170306-all.json')

        def parse_line(line):
            # Skip blank lines and the opening/closing brackets of the top-level JSON array.
            if not line or line[0] in '[]':
                return []
            # Entities are comma-separated; drop the separator before decoding.
            if line.endswith(','):
                line = line[:-1]
            return [json.loads(line)]

        parsed_wikidata = wikidata.flatMap(parse_line).cache()
        property_map = extract_property_map(parsed_wikidata)
        b_property_map = sc.broadcast(property_map)

        wikidata_items = parsed_wikidata.filter(lambda d: d['type'] == 'item').cache()
        item_page_map = extract_item_page_map(wikidata_items)
        # Release the parent cache only after wikidata_items has (presumably) been materialized
        # by the call above. Unpersisting it earlier discards the cache that the lazy filter
        # reads from, forcing Spark to re-read and re-parse the whole dump.
        parsed_wikidata.unpersist()
        b_item_page_map = sc.broadcast(item_page_map)

        parsed_item_map = extract_items(wikidata_items, b_property_map, b_item_page_map)

        with open(output, 'wb') as f:
            pickle.dump({
                'parsed_item_map': parsed_item_map,
                'item_page_map': item_page_map,
                'property_map': property_map
            }, f)
    finally:
        # Stop the context even when a Spark job or the output write fails.
        sc.stop()
Example #15
Source File: test_spark.py    From snorkel with Apache License 2.0 5 votes vote down vote up
def test_lf_applier_spark(self) -> None:
        """Applying LFs f and g over a Spark RDD produces L_EXPECTED."""
        sql_ctx = SQLContext(SparkContext.getOrCreate())
        rdd = sql_ctx.createDataFrame(pd.DataFrame(dict(num=DATA))).rdd
        label_matrix = SparkLFApplier([f, g]).apply(rdd)
        np.testing.assert_equal(label_matrix, L_EXPECTED)
Example #16
Source File: table.py    From FATE with Apache License 2.0 5 votes vote down vote up
def _rdd_from_dtable(self):
        """Copy the wrapped DTable into a persisted Spark RDD, store it on self._rdd and return it.

        The table is collected serialized (``use_serialize=True``) and parallelized with the
        table's own partition count. An empty table is parallelized as an empty list.
        """
        items = self._dtable.collect(use_serialize=True)
        if self._dtable.count() <= 0:
            items = []

        partitions = self._dtable._partitions
        from pyspark import SparkContext
        rdd = SparkContext.getOrCreate().parallelize(items, partitions)
        self._rdd = rdd.persist(util.get_storage_level())
        return self._rdd
Example #17
Source File: util.py    From FATE with Apache License 2.0 5 votes vote down vote up
def broadcast_eggroll_session(work_mode, eggroll_session):
    """Store ``(work_mode.value, eggroll_session)`` as a Spark local property.

    The tuple is pickled and hex-encoded, then set on the current SparkContext under the key
    ``_EGGROLL_CLIENT``.
    """
    import pickle
    payload = pickle.dumps((work_mode.value, eggroll_session))
    hex_payload = payload.hex()
    from pyspark import SparkContext
    spark_ctx = SparkContext.getOrCreate()
    spark_ctx.setLocalProperty(_EGGROLL_CLIENT, hex_payload)


# noinspection PyProtectedMember,PyUnresolvedReferences 
Example #18
Source File: table.py    From FATE with Apache License 2.0 5 votes vote down vote up
def _rdd_from_dtable(self):
        """Copy the wrapped DTable into a persisted Spark RDD, store it on self._rdd and return it.

        All entries come from ``get_all()`` and are parallelized with the table's partition
        count (``get_partitions()``). An empty table is parallelized as an empty list.
        """
        items = self._dtable.get_all()
        if self._dtable.count() <= 0:
            items = []

        partitions = self._dtable.get_partitions()

        from pyspark import SparkContext
        rdd = SparkContext.getOrCreate().parallelize(items, partitions)
        self._rdd = rdd.persist(util.get_storage_level())
        return self._rdd
Example #19
Source File: test_spark.py    From sentry-python with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def test_start_sentry_listener():
    """_start_sentry_listener starts the py4j callback server on the context's gateway."""
    sc = SparkContext.getOrCreate()
    gw = sc._gateway

    # No callback server before the listener is started ...
    assert gw._callback_server is None
    _start_sentry_listener(sc)
    # ... and one afterwards.
    assert gw._callback_server is not None
Example #20
Source File: imageIO.py    From spark-deep-learning with Apache License 2.0 5 votes vote down vote up
def readImagesWithCustomFn(path, decode_f, numPartition=None):
    """
    Read an image file, or a directory of images, into a DataFrame, decoding the raw bytes
    with a caller-supplied function. Deprecated: emits a DeprecationWarning on every call.

    :param path: str, file path.
    :param decode_f: function that decodes raw bytes into an array compatible with one of the
        supported OpenCv modes; see @imageIO.PIL_decode for an example.
    :param numPartition: [optional] int, number of partitions to use for reading files.
    :return: DataFrame with schema == ImageSchema.imageSchema.
    """
    warnings.warn("readImagesWithCustomFn() will be removed in the next release of sparkdl. "
                  "Please use pillow and Pandas UDF instead.", DeprecationWarning)
    spark_ctx = SparkContext.getOrCreate()
    return _readImagesWithCustomFn(path, decode_f, numPartition, sc=spark_ctx)
Example #21
Source File: named_image.py    From spark-deep-learning with Apache License 2.0 5 votes vote down vote up
def _getScaleHintList():
    """Return the scale-hint names exposed by the JVM ``DeepImageFeaturizer``, as a list.

    When py4j resolves the class name to a ``JavaPackage`` (i.e. the class is not on the JVM
    classpath, presumably because the sparkdl JAR is missing), an empty list is returned
    instead of failing.
    """
    featurizer = SparkContext.getOrCreate()._jvm.com.databricks.sparkdl.DeepImageFeaturizer
    if isinstance(featurizer, py4j.java_gateway.JavaPackage):
        # DeepImageFeaturizer is not visible, possibly running without spark;
        # instead of failing, return an empty list.
        return []
    # Materialize the keys so both branches return the same type (a list), rather than a
    # dict_keys view here and a list above.
    return list(dict(featurizer.scaleHintsJava()).keys())
Example #22
Source File: spark.py    From qb with MIT License 5 votes vote down vote up
def create_spark_session(app_name='Quiz Bowl', configs=None) -> SparkSession:
    """Set up the SparkContext via create_spark_context, then return a SparkSession for it.

    :param app_name: Spark application name, forwarded to create_spark_context.
    :param configs: extra Spark settings, forwarded to create_spark_context.
    """
    create_spark_context(app_name=app_name, configs=configs)
    session = SparkSession.builder.getOrCreate()
    return session
Example #23
Source File: jpredDataset.py    From mmtf-pyspark with Apache License 2.0 4 votes vote down vote up
def get_dataset():
    '''Gets JPred 4/JNet (v.2.3.1) secondary structure dataset.

    Downloads the JPred archive and pairs each ``.fasta`` record's sequence with the matching
    ``.dssp`` record's secondary structure ('-' replaced by 'C'). Each row is flagged
    "true"/"false" depending on whether its ``.fasta`` file lies under ``training/`` or
    ``blind/``. The rows are converted to a dataset with ``pythonRDDToDataset.get_dataset``.

    Returns
    -------
    dataset
       secondaryStructure dataset with columns scopID, sequence, secondaryStructure, trained

    Raises
    ------
    KeyError
       if a ``.fasta`` record has no matching ``.dssp`` record, or is in neither the
       ``training/`` nor the ``blind/`` directory
    '''

    URL = "http://www.compbio.dundee.ac.uk/jpred/downloads/retr231.tar.gz"
    secondaryStructures, sequences, trained = {}, {}, {}
    scopIds = set()

    def read_record(tf, entry):
        """Return (scop_id, data_line) from a record: a '>'-prefixed ID line, then one data line."""
        # Decode the bytes instead of slicing str(bytes): that trick silently drops real
        # characters when a line lacks a trailing newline, leaves '\r' behind on CRLF files,
        # and mangles any character that repr() escapes.
        br = tf.extractfile(entry)
        header = br.readline().decode().rstrip("\r\n")
        data = br.readline().decode().rstrip("\r\n")
        return header[1:], data  # header[1:] drops the leading '>'

    # Close both the HTTP response and the archive when done.
    with urllib.request.urlopen(URL) as instream, \
            tarfile.open(fileobj=instream, mode="r:gz") as tf:
        for entry in tf:
            if entry.isdir():
                continue

            if ".dssp" in entry.name:
                scopID, secondaryStructure = read_record(tf, entry)
                # Map '-' to 'C' (presumably coil) as the dataset's label convention.
                secondaryStructures[scopID] = secondaryStructure.replace('-', 'C')

            if ".fasta" in entry.name:
                scopID, sequence = read_record(tf, entry)
                scopIds.add(scopID)
                sequences[scopID] = sequence

                if "training/" in entry.name:
                    trained[scopID] = "true"
                elif "blind/" in entry.name:
                    trained[scopID] = "false"

    res = [Row(scopId, sequences[scopId], secondaryStructures[scopId], trained[scopId])
           for scopId in scopIds]

    sc = SparkContext.getOrCreate()
    data = sc.parallelize(res)
    colNames = ["scopID", "sequence", "secondaryStructure", "trained"]

    return pythonRDDToDataset.get_dataset(data, colNames)
Example #24
Source File: _spark_autologging.py    From mlflow with Apache License 2.0 4 votes vote down vote up
def autolog():
    """Implementation of Spark datasource autologging.

    If no listener is registered yet, this:

    - requires an active SparkSession running Spark 3 or newer;
    - starts the py4j callback server on the session's gateway;
    - initializes the JVM-side event publisher and registers a ``PythonSubscriber``;
    - registers ``SparkAutologgingContext`` as a run context provider.

    If a listener is already registered, it does nothing.

    :raises MlflowException: if there is no active SparkSession, the Spark major version is
        below 3, or the JVM-side initialization fails (e.g. the mlflow-spark JAR is missing).
    """
    global _spark_table_info_listener
    if _get_current_listener() is not None:
        return

    active_session = _get_active_spark_session()
    if active_session is None:
        raise MlflowException(
            "No active SparkContext found, refusing to enable Spark datasource "
            "autologging. Please create a SparkSession e.g. via "
            "SparkSession.builder.getOrCreate() (see API docs at "
            "https://spark.apache.org/docs/latest/api/python/"
            "pyspark.sql.html#pyspark.sql.SparkSession) "
            "before attempting to enable autologging")
    # Use the active session's own SparkContext, the same context whose gateway is configured
    # below, rather than looking it up again with SparkContext.getOrCreate().
    sc = active_session.sparkContext
    if _get_spark_major_version(sc) < 3:
        raise MlflowException(
            "Spark autologging unsupported for Spark versions < 3")
    gw = sc._gateway
    params = gw.callback_server_parameters
    # Same settings as the gateway's defaults, but daemonized so the callback server does not
    # keep the Python process alive.
    callback_server_params = CallbackServerParameters(
        address=params.address, port=params.port, daemonize=True, daemonize_connections=True,
        eager_load=params.eager_load, ssl_context=params.ssl_context,
        accept_timeout=params.accept_timeout, read_timeout=params.read_timeout,
        auth_token=params.auth_token)
    gw.start_callback_server(callback_server_params)

    event_publisher = _get_jvm_event_publisher()
    try:
        event_publisher.init(1)
        _spark_table_info_listener = PythonSubscriber()
        _spark_table_info_listener.register()
    except Exception as e:
        # Chain the original error so its traceback is preserved.
        raise MlflowException("Exception while attempting to initialize JVM-side state for "
                              "Spark datasource autologging. Please ensure you have the "
                              "mlflow-spark JAR attached to your Spark session as described "
                              "in http://mlflow.org/docs/latest/tracking.html#"
                              "automatic-logging-from-spark-experimental. Exception:\n%s"
                              % e) from e

    # Register context provider for Spark autologging
    from mlflow.tracking.context.registry import _run_context_provider_registry
    _run_context_provider_registry.register(SparkAutologgingContext)
Example #25
Source File: spark.py    From mlflow with Apache License 2.0 4 votes vote down vote up
def autolog():
    """
    Enables automatic logging of Spark datasource paths, versions (if applicable), and formats
    when they are read. This method is not threadsafe and assumes a
    `SparkSession
    <https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.SparkSession>`_
    already exists with the
    `mlflow-spark JAR
    <http://mlflow.org/docs/latest/tracking.html#automatic-logging-from-spark-experimental>`_
    attached. It should be called on the Spark driver, not on the executors (i.e. do not call
    this method within a function parallelized by Spark). This API requires Spark 3.0 or above.

    Datasource information is logged under the current active MLflow run. If no active run
    exists, datasource information is cached in memory & logged to the next-created active run
    (but not to successive runs). Note that autologging of Spark ML (MLlib) models is not currently
    supported via this API. Datasource-autologging is best-effort, meaning that if Spark is under
    heavy load or MLflow logging fails for any reason (e.g., if the MLflow server is unavailable),
    logging may be dropped.

    For any unexpected issues with autologging, check Spark driver and executor logs in addition
    to stderr & stdout generated from your MLflow code - datasource information is pulled from
    Spark, so logs relevant to debugging may show up amongst the Spark logs.

    .. code-block:: python
        :caption: Example

        import shutil
        import tempfile

        import mlflow.spark
        from pyspark.sql import SparkSession
        # Create and persist some dummy data.
        # spark.jars.packages takes Maven coordinates (groupId:artifactId:version);
        # replace <version> with the MLflow version you are using.
        spark = (SparkSession.builder
                    .config("spark.jars.packages", "org.mlflow:mlflow-spark:<version>")
                    .getOrCreate())
        df = spark.createDataFrame([
                (4, "spark i j k"),
                (5, "l m n"),
                (6, "spark hadoop spark"),
                (7, "apache hadoop")], ["id", "text"])
        tempdir = tempfile.mkdtemp()
        df.write.format("csv").save(tempdir)
        # Enable Spark datasource autologging.
        mlflow.spark.autolog()
        loaded_df = spark.read.format("csv").load(tempdir)
        # Call collect() to trigger a read of the Spark datasource. Datasource info
        # (path and format) is automatically logged to an MLflow run.
        loaded_df.collect()
        shutil.rmtree(tempdir)  # clean up tempdir
    """
    # Delegate to the implementation module.
    from mlflow import _spark_autologging
    _spark_autologging.autolog()