Python pyspark.sql.functions.col() Examples

The following are 30 code examples of pyspark.sql.functions.col(), drawn from open source projects. The source file and project license are noted above each example. You may also want to check out all available functions and classes of the module pyspark.sql.functions.
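As background for the examples below, col() returns a Column expression that refers to a DataFrame column by name; the expression can then be combined with operators, renamed, and used in selections, filters, and orderings. A minimal sketch with made-up column names:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a", 10.0), (2, "b", 3.5)], ["id", "label", "price"])

result = (
    df
    .where(F.col("price") > 5.0)                                      # filter on a column expression
    .select(F.col("id"), (F.col("price") * 2).alias("double_price"))  # arithmetic and renaming
    .orderBy(F.col("id").desc()))                                     # ordering

Many projects import it directly (from pyspark.sql.functions import col) rather than through the F alias; both forms appear in the examples below.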
Example #1
Source File: helpers.py    From SMV with Apache License 2.0
def smvTimestampToStr(self, timezone, fmt):
        """Build a string from a timestamp and timezone

            Args:
                timezone (string or Column): the timezone follows the rules in 
                    https://www.joda.org/joda-time/apidocs/org/joda/time/DateTimeZone.html#forID-java.lang.String-
                    It can be a string like "America/Los_Angeles" or "+1000". If it is null, use current system time zone.
                fmt (string): the format is the same as the Java `Date` format

            Example:
                >>> df.select(col("ts").smvTimestampToStr("America/Los_Angeles","yyyy-MM-dd HH:mm:ss"))

            Returns:
                (Column): StringType. The converted String with given format
        """
        if is_string(timezone):
            jtimezone = timezone
        elif isinstance(timezone, Column):
            jtimezone = timezone._jc
        else:
            raise RuntimeError("timezone parameter must be either an string or a Column")
        jc = self._jColumnHelper.smvTimestampToStr(jtimezone, fmt)
        return Column(jc) 
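For comparison, stock PySpark can produce a similar result without SMV by shifting a timestamp into a target time zone and formatting it. This is only a rough sketch; edge-case semantics (null time zones, offset strings, how the column was originally stored) may differ from SMV's helper, and df with a "ts" column is assumed:

from pyspark.sql import functions as F

# from_utc_timestamp treats the input as UTC and renders it in the given zone
df.select(
    F.date_format(
        F.from_utc_timestamp(F.col("ts"), "America/Los_Angeles"),
        "yyyy-MM-dd HH:mm:ss"
    ).alias("ts_str"))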
Example #2
Source File: transform.py    From search-MjoLniR with MIT License
def for_each_item(
    col_name: str,
    items: List[_LT],
    transformer_factory: Callable[[_LT], Transformer],
    mapper=map
) -> Transformer:
    """Run a transformation for each value in a list of values"""
    # A lambda inside the list comprehension would capture `item`
    # by name, use a proper function to ensure item is captured
    # from a unique context.
    def restrict_to_item(item: _LT) -> Transformer:
        return lambda df: df.where(F.col(col_name) == item)

    transformers = [seq_transform([
        restrict_to_item(item),
        transformer_factory(item)
    ]) for item in items]

    return par_transform(transformers, mapper)


# Shared transformations 
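The comment in for_each_item above points at a classic Python pitfall: a lambda defined inside a loop or comprehension looks up the loop variable when it is called, not when it is created, so every lambda would see the last item. A small standalone illustration (plain Python, unrelated to Spark) of why the restrict_to_item factory function is used:

# Late binding: each lambda reads `item` from the enclosing scope at call time
broken = [lambda x: x + item for item in (1, 2, 3)]
print([f(0) for f in broken])   # [3, 3, 3]

# Binding through a factory function, as restrict_to_item does, freezes each value
def make_adder(item):
    return lambda x: x + item

fixed = [make_adder(item) for item in (1, 2, 3)]
print([f(0) for f in fixed])    # [1, 2, 3]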
Example #3
Source File: tuning.py    From search-MjoLniR with MIT License
def group_k_fold(df, num_folds, output_column='fold'):
    """
    Generates group k-fold splits. The fold a row belongs to is
    assigned to the column identified by the output_column parameter.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
    num_folds : int
    output_column : str, optional

    Returns
    ------
    pyspark.sql.DataFrame
        Input data frame with a 'fold' column indicating fold membership.
        Normalized queries are equally distributed to each fold.
    """
    return (
        split(df, [1. / num_folds] * num_folds, output_column)
        .withColumn(output_column, mjolnir.spark.add_meta(df._sc, F.col(output_column), {
            'num_folds': num_folds,
        }))) 
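split() and add_meta() are MjoLniR helpers, but the core idea of group-preserving fold assignment can be sketched with stock PySpark: hash the group key so every row of a group lands in the same fold. This naive version does not balance fold sizes as carefully as the project's split():

from pyspark.sql import functions as F

def naive_group_k_fold(df, num_folds, output_column='fold'):
    # crc32 of the stringified group key, modulo the fold count, keeps all rows
    # of a norm_query_id together while spreading groups across folds
    return df.withColumn(
        output_column,
        (F.crc32(F.col('norm_query_id').cast('string')) % num_folds).cast('int'))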
Example #4
Source File: test_tuning.py    From search-MjoLniR with MIT License
def test_split(spark):
    df = (
        spark
        .range(1, 100 * 100)
        # convert into 100 "queries" with 100 values each. We need a
        # sufficiently large number of queries, or the split won't have
        # enough data for partitions to even out.
        .select(F.lit('foowiki').alias('wikiid'),
                (F.col('id')/100).cast('int').alias('norm_query_id')))

    with_folds = mjolnir.training.tuning.split(df, (0.8, 0.2)).collect()

    fold_0 = [row for row in with_folds if row.fold == 0]
    fold_1 = [row for row in with_folds if row.fold == 1]

    # Check the folds are pretty close to requested
    total_len = float(len(with_folds))
    assert 0.8 == pytest.approx(len(fold_0) / total_len, abs=0.015)
    assert 0.2 == pytest.approx(len(fold_1) / total_len, abs=0.015)

    # Check each norm query is only found on one side of the split
    queries_in_0 = set([row.norm_query_id for row in fold_0])
    queries_in_1 = set([row.norm_query_id for row in fold_1])
    assert len(queries_in_0.intersection(queries_in_1)) == 0 
Example #5
Source File: norm_query_clustering.py    From search-MjoLniR with MIT License
def filter_min_sessions_per_norm_query(min_sessions: int) -> mt.Transformer:
    def transform(df: DataFrame) -> DataFrame:
        w = Window.partitionBy('wikiid', 'norm_query')
        return (
            df.withColumn(
                'has_min_sessions',
                at_least_n_distinct('session_id', min_sessions).over(w))
            .where(F.col('has_min_sessions'))
            .drop('has_min_sessions'))
    return transform 
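at_least_n_distinct is a helper defined elsewhere in search-MjoLniR. Since COUNT(DISTINCT ...) is not supported directly over a window, a rough stand-in with built-in functions is collect_set plus size (a sketch, not the project's implementation):

from pyspark.sql import Window
from pyspark.sql import functions as F

def filter_min_sessions_naive(df, min_sessions):
    w = Window.partitionBy('wikiid', 'norm_query')
    # size(collect_set(...)) counts distinct session_ids per (wikiid, norm_query)
    return (
        df.withColumn('n_sessions', F.size(F.collect_set('session_id').over(w)))
        .where(F.col('n_sessions') >= min_sessions)
        .drop('n_sessions'))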
Example #6
Source File: ProteinChainClassification.ipynb.py    From mmtf-pyspark with Apache License 2.0
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" with three major secondary structure class:
    "alpha", "beta", "alpha+beta", and "other" based upon the fraction of alpha/beta content.

    The simplified syntax used in this method relies on two imports:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Attributes:
        data (Dataset<Row>): input dataset with alpha, beta composition
        minThreshold (float): below this threshold, the secondary structure is ignored
        maxThreshold (float): above this threshold, the secondary structure is ignored
    '''

    return data.withColumn("foldType",                            when((col("alpha") > maxThreshold) & (col("beta") < minThreshold), "alpha").                            when((col("beta") > maxThreshold) & (col("alpha") < minThreshold), "beta").                            when((col("alpha") > maxThreshold) & (col("beta") > maxThreshold), "alpha+beta").                            otherwise("other")                           )


# ## Classify chains by secondary structure type

# In[22]: 
Example #7
Source File: feature_selection.py    From search-MjoLniR with MIT License
def select_features(
    wiki: str,
    num_features: int,
    metadata: Dict
) -> mt.Transformer:
    def transform(df: DataFrame) -> DataFrame:
        # Compute the "best" features, per some metric
        sc = df.sql_ctx.sparkSession.sparkContext
        features = metadata['input_feature_meta']['features']
        selected = mjolnir.feature_engineering.select_features(
            sc, df, features, num_features, algo='mrmr')
        metadata['wiki_features'][wiki] = selected

        # Rebuild the `features` col with only the selected features
        keep_cols = metadata['default_cols'] + selected
        df_selected = df.select(*keep_cols)
        assembler = VectorAssembler(
            inputCols=selected, outputCol='features')
        return assembler.transform(df_selected).drop(*selected)
    return transform 
Example #8
Source File: transform.py    From search-MjoLniR with MIT License
def cache_to_disk(temp_dir: str, partition_by: str) -> Transformer:
    """Write a dataframe to disk partitioned by a column.

    Writes out the source dataframe partitioned by the provided
    column. The intention is for downstream tasks to construct
    a dataframe per partitioned value. When doing so this allows
    the downstream data frames to read individual columns for specific
    wikis from disk directly.

    Cleaning up the temp_dir is the caller's responsibility and must
    be done after all transformations have executed to completion,
    likely after closing the SparkContext.

    TODO: This emits the same number of partitions for each partition col,
    while some may need 1 partition and others 1000. We would need count
    estimates to do that partitioning though.
    """
    def transform(df: DataFrame) -> DataFrame:
        df.write.partitionBy(partition_by).parquet(temp_dir)
        return df.sql_ctx.read.parquet(temp_dir)
    return transform 
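A downstream consumer would then read back just the rows for one partitioned value; because the data was written with partitionBy, a filter on that column prunes directories instead of scanning everything. A hedged sketch of such a consumer (the function name and usage are illustrative, not part of the project):

from pyspark.sql import functions as F

def read_one_partition(spark, temp_dir, partition_by, value):
    # Only the <partition_by>=<value> directory is read thanks to partition pruning
    return spark.read.parquet(temp_dir).where(F.col(partition_by) == value)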
Example #9
Source File: testColumnHelper.py    From SMV with Apache License 2.0
def test_smvDateTimeFunctions(self):
        df = self.createDF("k:Timestamp[yyyyMMdd]; v:String;", "20190101,a;,b")
        res = df.select(col("k").smvYear(), col("k").smvMonth(), col("k").smvQuarter(), col("k").smvDayOfMonth(), col("k").smvDayOfWeek(), col("k").smvHour())
        expected = self.createDF("SmvYear(k): Integer; SmvMonth(k): Integer; SmvQuarter(k): Integer; SmvDayOfMonth(k): Integer; SmvDayOfWeek(k): Integer; SmvHour(k): Integer", "2019,1,1,1,2,0;" + ",,,,,")

        if sys.version < '3':
            self.should_be_same(expected, res)
        else:
            # Python 3 is a bit picky about null ordering
            self.assertEquals(expected.columns, res.columns)
            a = expected.collect()
            b = res.collect()
            try: a.sort()
            except TypeError: pass
            try: b.sort()
            except TypeError: pass
            self.assertEqual(a, b) 
Example #10
Source File: testSmvGroupedData.py    From SMV with Apache License 2.0
def test_smvTimePanelAgg_with_Week(self):
        df = self.createDF("k:Integer; ts:String; v:Double",
                 "1,20120301,1.5;" +
                 "1,20120304,4.5;" +
                 "1,20120308,7.5;" +
                 "1,20120309,2.45"
             ).withColumn("ts", col('ts').smvStrToTimestamp("yyyyMMdd"))

        import smv.panel as p

        res = df.smvGroupBy('k').smvTimePanelAgg(
            'ts', p.Week(2012, 3, 1), p.Week(2012, 3, 10)
        )(
            sum('v').alias('v')
        )

        expect = self.createDF("k: Integer;smvTime: String;v: Double",
            """1,W20120305,9.95;
                1,W20120227,6.0""")

        self.should_be_same(res, expect) 
Example #11
Source File: testSmvGroupedData.py    From SMV with Apache License 2.0
def test_smvTimePanelAgg(self):
        df = self.createDF("k:Integer; ts:String; v:Double",
            """1,20120101,1.5;
                1,20120301,4.5;
                1,20120701,7.5;
                1,20120501,2.45"""
            ).withColumn("ts", col('ts').smvStrToTimestamp("yyyyMMdd"))

        import smv.panel as p

        res = df.smvGroupBy('k').smvTimePanelAgg(
            'ts', p.Quarter(2012,1), p.Quarter(2012,2)
        )(
            sum('v').alias('v')
        )

        expect = self.createDF("k: Integer;smvTime: String;v: Double",
                """1,Q201201,6.0;
                    1,Q201202,2.45""")

        self.should_be_same(expect, res) 
Example #12
Source File: utils.py    From spylon with BSD 3-Clause "New" or "Revised" License
def wrap_function_cols(self, name, package_name=None, object_name=None, java_class_instance=None, doc=""):
        """Utility method for wrapping a scala/java function that returns a spark sql Column.

        This assumes that the function that you are wrapping takes a list of spark sql Column objects as its arguments.
        """
        def _(*cols):
            jcontainer = self.get_java_container(package_name=package_name, object_name=object_name, java_class_instance=java_class_instance)
            # Ensure that your argument is a column
            col_args = [col._jc if isinstance(col, Column) else _make_col(col)._jc for col in cols]
            function = getattr(jcontainer, name)
            args = col_args
            jc = function(*args)
            return Column(jc)
        _.__name__ = name
        _.__doc__ = doc
        return _ 
Example #13
Source File: testSmvGroupedData.py    From SMV with Apache License 2.0
def test_smvWithTimePanel(self):
        df = self.createDF("k:Integer; ts:String; v:Double",
            """1,20120101,1.5;
                1,20120301,4.5;
                1,20120701,7.5;
                1,20120501,2.45"""
            ).withColumn("ts", col('ts').smvStrToTimestamp("yyyyMMdd"))

        import smv.panel as p

        res = df.smvGroupBy('k').smvWithTimePanel(
            'ts', p.Month(2012,1), p.Month(2012,3)
        )

        expect = self.createDF("k: Integer;ts: String;v: Double;smvTime: String",
        """1,,,M201202;
            1,,,M201201;
            1,,,M201203;
            1,20120101,1.5,M201201;
            1,20120301,4.5,M201203""").withColumn("ts", col('ts').smvStrToTimestamp("yyyyMMdd"))

        self.should_be_same(expect, res) 
Example #14
Source File: WaterInteractionsExample.ipynb.py    From mmtf-pyspark with Apache License 2.0
def filter_bridging_water_interactions(data, maxInteractions):
    if maxInteractions == 2:
        data = data.filter((col("type1") == "LGO") |                            (col("type2") == "LGO"))
        data = data.filter((col("type1") == "PRO") |                            (col("type2") == "PRO"))
    elif maxInteractions == 3:
        data = data.filter((col("type1") == "LGO") |                            (col("type2") == "LGO") |                            (col("type3") == "LGO"))
        data = data.filter((col("type1") == "PRO") |                            (col("type2") == "PRO") |                            (col("type3") == "PRO"))
    elif maxInteractions == 4:
        data = data.filter((col("type1") == "LGO") |                            (col("type2") == "LGO") |                            (col("type3") == "LGO") |                            (col("type4") == "LGO"))
        data = data.filter((col("type1") == "PRO") |                            (col("type2") == "PRO") |                            (col("type3") == "PRO") |                            (col("type4") == "PRO"))
    else:
        raise ValueError("maxInteractions > 4 are not supported yet")
    return data


# ## Keep only interactions with at least one organic ligand and one protein interaction

# In[8]: 
Example #15
Source File: helpers.py    From SMV with Apache License 2.0
def smvPlusYears(self, delta):
        """Add N years to `Timestamp` or `Date` column

            Args:
                delta (int or Column): the number of years to add

            Example:
                >>> df.select(col("dob").smvPlusYears(3))

            Returns:
                (Column): TimestampType. The incremented Timestamp, or null if input is null.
                    **Note** even if the input is DateType, the output is TimestampType
        """
        if (isinstance(delta, int)):
            jdelta = delta
        elif (isinstance(delta, Column)):
            jdelta = delta._jc
        else:
            raise RuntimeError("delta parameter must be either an int or a Column")
        jc = self._jColumnHelper.smvPlusYears(jdelta)
        return Column(jc) 
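A rough stock-PySpark counterpart uses add_months with 12 months per year; note that, unlike the SMV helper, add_months returns a DateType and drops any time-of-day component. A sketch assuming a "dob" column as in the example above:

from pyspark.sql import functions as F

# Add 3 years by adding 36 months
df.select(F.add_months(F.col("dob"), 3 * 12).alias("dob_plus_3y"))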
Example #16
Source File: testDataFrameHelper.py    From SMV with Apache License 2.0
def test_smvDedupByKey_with_column(self):
        schema = "a:Integer; b:Double; c:String"
        df = self.createDF(
            schema,
            """1,2.0,hello;
            1,3.0,hello;
            2,10.0,hello2;
            2,11.0,hello3"""
        )
        r1 = df.smvDedupByKey(col("a"))
        expect = self.createDF(
            schema,
            """1,2.0,hello;
            2,10.0,hello2"""
        )
        self.should_be_same(expect, r1) 
Example #17
Source File: helpers.py    From SMV with Apache License 2.0
def smvPlusWeeks(self, delta):
        """Add N weeks to `Timestamp` or `Date` column

            Args:
                delta (int or Column): the number of weeks to add

            Example:
                >>> df.select(col("dob").smvPlusWeeks(3))

            Returns:
                (Column): TimestampType. The incremented Timestamp, or null if input is null.
                    **Note** even if the input is DateType, the output is TimestampType
        """
        if (isinstance(delta, int)):
            jdelta = delta
        elif (isinstance(delta, Column)):
            jdelta = delta._jc
        else:
            raise RuntimeError("delta parameter must be either an int or a Column")
        jc = self._jColumnHelper.smvPlusWeeks(jdelta)
        return Column(jc) 
Example #18
Source File: testDataFrameHelper.py    From SMV with Apache License 2.0
def test_smvDedupByKeyWithOrder_with_string(self):
        schema = "a:Integer; b:Double; c:String"
        df = self.createDF(
            schema,
            """1,2.0,hello;
            1,3.0,hello;
            2,10.0,hello2;
            2,11.0,hello3"""
        )
        r1 = df.smvDedupByKeyWithOrder("a")(col("b").desc())
        expect = self.createDF(
            schema,
            """1,3.0,hello;
            2,11.0,hello3"""
        )
        self.should_be_same(expect, r1) 
Example #19
Source File: test_imageIO.py    From spark-deep-learning with Apache License 2.0
def test_readImages(self):
        # Test that reading images with a custom decode function yields an image column
        imageDF = imageIO._readImagesWithCustomFn(
            "file/path", decode_f=imageIO.PIL_decode, numPartition=2, sc=self.binaryFilesMock)
        self.assertTrue("image" in imageDF.schema.names)

        # The DF should have 2 images and 1 null.
        self.assertEqual(imageDF.count(), 3)
        validImages = imageDF.filter(col("image").isNotNull())
        self.assertEqual(validImages.count(), 2)

        img = validImages.first().image
        self.assertEqual(img.height, array.shape[0])
        self.assertEqual(img.width, array.shape[1])
        self.assertEqual(imageIO.imageTypeByOrdinal(img.mode).nChannels, array.shape[2])
        # array comes out of PIL and is in RGB order
        self.assertEqual(img.data, array.tobytes()) 
Example #20
Source File: testDataFrameHelper.py    From SMV with Apache License 2.0
def test_smvDedupByKeyWithOrder_with_column(self):
        schema = "a:Integer; b:Double; c:String"
        df = self.createDF(
            schema,
            """1,2.0,hello;
            1,3.0,hello;
            2,10.0,hello2;
            2,11.0,hello3"""
        )
        r1 = df.smvDedupByKeyWithOrder(col("a"))(col("b").desc())
        expect = self.createDF(
            schema,
            """1,3.0,hello;
            2,11.0,hello3"""
        )
        self.should_be_same(expect, r1) 
Example #21
Source File: helpers.py    From SMV with Apache License 2.0
def smvDupeCheck(self, keys, n=10000):
        """For a given list of potential keys, check for duplicated records with the number of duplications and all the columns.

            Null values are allowed in the potential keys, so duplication on Null valued keys will also be reported.

            Args:
                keys (list(string)): the key columns to which the duplicate check is applied
                n (integer): number of rows from the input data to check for duplications, defaults to 10000

            Returns:
                (DataFrame): the key columns, a "_N" count column, and the remaining columns, for records whose
                    key values occur more than once; "_N" holds the number of duplications of that record's key values
        """
        dfTopN = self.df.limit(n).cache()

        res = dfTopN.groupBy(*keys)\
            .agg(F.count(F.lit(1)).alias('_N'))\
            .where(F.col('_N') > 1)\
            .smvJoinByKey(dfTopN, keys, 'inner', True)\
            .orderBy(*keys)

        dfTopN.unpersist()
        return res 
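smvJoinByKey is SMV-specific, but the same duplicate report can be approximated with a plain groupBy/count and an inner join back onto the sampled rows (a rough sketch, not SMV's implementation):

from pyspark.sql import functions as F

def dupe_check(df, keys, n=10000):
    sample = df.limit(n).cache()
    counts = (
        sample.groupBy(*keys)
        .agg(F.count(F.lit(1)).alias('_N'))
        .where(F.col('_N') > 1))
    # Inner join keeps only rows whose key combination is duplicated
    return counts.join(sample, on=keys, how='inner').orderBy(*keys)

One difference worth noting: a plain equi-join drops rows whose key values are null, whereas the SMV helper explicitly reports duplicates on null-valued keys.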
Example #22
Source File: helpers.py    From SMV with Apache License 2.0
def smvSelectPlus(self, *cols):
        """Selects all the current columns in current DataFrame plus the supplied expressions

            The new columns are added to the end of the current column list.

            Args:
                cols (\*Column): expressions to add to the DataFrame

            Example:
                >>> df.smvSelectPlus((col("price") * col("count")).alias("amt"))

            Returns:
                (DataFrame): the resulting DataFrame with the supplied expressions appended
        """
        jdf = self._jDfHelper.smvSelectPlus(_to_seq(cols, _jcol))
        return DataFrame(jdf, self._sql_ctx) 
Example #23
Source File: helpers.py    From SMV with Apache License 2.0
def topNValsByFreq(self, n, col):
        """Get top N most frequent values in Column col

            Args:
                n (int): maximum number of values
                col (Column): which column to get values from

            Example:

                >>> df.topNValsByFreq(1, col("cid"))

                will return the single most frequent value in the cid column

            Returns:
                (list(object)): most frequent values (type depends on schema)
        """
        topNdf = DataFrame(self._jDfHelper._topNValsByFreq(n, col._jc), self._sql_ctx)
        return [list(r)[0] for r in topNdf.collect()] 
Example #24
Source File: helpers.py    From SMV with Apache License 2.0
def smvTopNRecs(self, maxElems, *cols):
        """For each group, return the top N records according to a given ordering

            Example:

                >>> df.smvGroupBy("id").smvTopNRecs(3, col("amt").desc())

                This will keep the 3 largest amt records for each id

            Args:
                maxElems (int): maximum number of records per group
                cols (\*Column): columns defining the ordering

            Returns:
                (DataFrame): result of taking top records from groups

        """
        return DataFrame(self.sgd.smvTopNRecs(maxElems, smv_copy_array(self.df._sc, *cols)), self.df.sql_ctx) 
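The same "top N per group" result can be sketched with plain window functions, which is the usual pattern outside SMV (column names follow the example above):

from pyspark.sql import Window
from pyspark.sql import functions as F

def top_n_per_group(df, n, group_col, order_col):
    w = Window.partitionBy(group_col).orderBy(F.col(order_col).desc())
    return (
        df.withColumn('_rn', F.row_number().over(w))
        .where(F.col('_rn') <= n)
        .drop('_rn'))

# e.g. keep the 3 largest amt records per id:
# top_n_per_group(df, 3, 'id', 'amt')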
Example #25
Source File: testDataFrameHelper.py    From SMV with Apache License 2.0
def test_smvGetDesc(self):
        df = self.createDF("a:String", "a")
        res = df.smvDesc(("a", "this is col a"))
        self.assertEqual(res.smvGetDesc("a"), "this is col a")
        self.assertEqual(res.smvGetDesc(), [("a", "this is col a")]) 
Example #26
Source File: testDataFrameHelper.py    From SMV with Apache License 2.0
def test_smvSelectMinus_with_column(self):
        schema = "k:String;v1:Integer;v2:Integer"
        df = self.createDF(schema, "a,1,2;b,2,3")
        r1 = df.smvSelectMinus(col("v1"))
        expect = self.createDF("k:String;v2:Integer", "a,2;b,3")
        self.should_be_same(expect, r1) 
Example #27
Source File: testDataFrameHelper.py    From SMV with Apache License 2.0
def test_smvRenameField_preserve_meta_for_renamed_fields(self):
        df = self.createDF("a:Integer; b:String", "1,abc;1,def;2,ghij")
        desc = "c description"
        res1 = df.groupBy(col("a")).agg(count(col("a")).alias("c"))\
                 .smvDesc(("c", desc))
        self.assertEqual(res1.smvGetDesc(), [("a", ""), ("c", desc)])

        res2 = res1.smvRenameField(("c", "d"))
        self.assertEqual(res2.smvGetDesc(), [("a", ""), ("d", desc)]) 
Example #28
Source File: testDataFrameHelper.py    From SMV with Apache License 2.0
def test_smvExpandStruct(self):
        schema = "id:String;a:Double;b:Double"
        df1 = self.createDF(schema, "a,1.0,10.0;a,2.0,20.0;b,3.0,30.0")
        df2 = df1.select(col("id"), struct("a", "b").alias("c"))
        res = df2.smvExpandStruct("c")
        expect = self.createDF(schema, "a,1.0,10.0;a,2.0,20.0;b,3.0,30.0")
        self.should_be_same(expect, res) 
Example #29
Source File: employment.py    From SMV with Apache License 2.0
def run(self, i):
        df = i[Employment]
        return df.groupBy(F.col("ST")).agg(F.sum(F.col("EMP")).alias("EMP")) 
Example #30
Source File: helpers.py    From SMV with Apache License 2.0
def smvHour(self):
        """Extract hour component from a timestamp

            Example:
                >>> df.select(col("dob").smvHour())

            Returns:
                (Column): IntegerType. Hour component as integer, or null if input column is null
        """
        jc = self._jColumnHelper.smvHour()
        return Column(jc)