Create column using Spark pandas_udf, with dynamic number of input columns

I have this df:

df = spark.createDataFrame(
    [('row_a', 5.0, 0.0, 11.0),
     ('row_b', 3394.0, 0.0, 4543.0),
     ('row_c', 136111.0, 0.0, 219255.0),
     ('row_d', 0.0, 0.0, 0.0),
     ('row_e', 0.0, 0.0, 0.0),
     ('row_f', 42.0, 0.0, 54.0)],
    ['value', 'col_a', 'col_b', 'col_c']
)

I would like to use .quantile(0.25, axis=1) from Pandas which would add one column:

import pandas as pd
pdf = df.toPandas()
pdf['25%'] = pdf.quantile(0.25, axis=1)
print(pdf)
#    value     col_a  col_b     col_c      25%
# 0  row_a       5.0    0.0      11.0      2.5
# 1  row_b    3394.0    0.0    4543.0   1697.0
# 2  row_c  136111.0    0.0  219255.0  68055.5
# 3  row_d       0.0    0.0       0.0      0.0
# 4  row_e       0.0    0.0       0.0      0.0
# 5  row_f      42.0    0.0      54.0     21.0

Performance is important to me, so I assume pandas_udf from pyspark.sql.functions could do this in a more optimized way. But I'm struggling to write a performant and useful function. This is my best attempt:

from pyspark.sql import functions as F
import pandas as pd
@F.pandas_udf('double')
def quartile1_on_axis1(a: pd.Series, b: pd.Series, c: pd.Series) -> pd.Series:
    pdf = pd.DataFrame({'a':a, 'b':b, 'c':c})
    return pdf.quantile(0.25, axis=1)

df = df.withColumn('25%', quartile1_on_axis1('col_a', 'col_b', 'col_c'))
  1. I don't like that I need a separate argument for every column, and that inside the function I have to address each argument individually to build the DataFrame. All of those columns serve the same purpose, so IMHO there should be a way to address them all together, something like in this pseudocode:

    def quartile1_on_axis1(*cols) -> pd.Series:
        pdf = pd.DataFrame(cols)
    

    This way I could use this function for any number of columns. (One possible approach is sketched right after this list.)

  2. Is it necessary to create a pd.DataFrame inside the UDF? To me this seems the same as not using a UDF at all (Spark df -> Pandas df -> Spark df), as shown above. Without the UDF it's even shorter. Should I really try to make it work with pandas_udf, performance-wise? I think pandas_udf was designed specifically for this kind of purpose...
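
Regarding point 1, here is a minimal sketch (not from the question, so verify it against your Spark version): the legacy, non-type-hinted form of pandas_udf accepts a variadic function, so a single UDF can cover any number of numeric columns. Spark 3 prefers the type-hinted style and emits a deprecation warning for the explicit PandasUDFType, but still supports it:

from pyspark.sql import functions as F
import pandas as pd

def _q25(*cols):
    # Each positional argument arrives as a pandas Series, one per input column.
    return pd.concat(cols, axis=1).quantile(0.25, axis=1)

# Legacy API: return type and UDF type are given explicitly instead of type hints.
quartile1_on_axis1 = F.pandas_udf(_q25, 'double', F.PandasUDFType.SCALAR)

df = df.withColumn('25%', quartile1_on_axis1('col_a', 'col_b', 'col_c'))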

Asked by ZygD

2 Answers

You can pass a single struct column instead of using multiple columns like this:

@F.pandas_udf('double')
def quartile1_on_axis1(s: pd.DataFrame) -> pd.Series:
    return s.quantile(0.25, axis=1)


cols = ['col_a', 'col_b', 'col_c']

df = df.withColumn('25%', quartile1_on_axis1(F.struct(*cols)))
df.show()

# +-----+--------+-----+--------+-------+
# |value|   col_a|col_b|   col_c|    25%|
# +-----+--------+-----+--------+-------+
# |row_a|     5.0|  0.0|    11.0|    2.5|
# |row_b|  3394.0|  0.0|  4543.0| 1697.0|
# |row_c|136111.0|  0.0|219255.0|68055.5|
# |row_d|     0.0|  0.0|     0.0|    0.0|
# |row_e|     0.0|  0.0|     0.0|    0.0|
# |row_f|    42.0|  0.0|    54.0|   21.0|
# +-----+--------+-----+--------+-------+

From the pyspark.sql.functions.pandas_udf documentation: the type hint should use pandas.Series in all cases; the one exception is that pandas.DataFrame should be used as the input or output type hint when the corresponding input or output column is of pyspark.sql.types.StructType.
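
As a possible extension of the same idea (the factory name and the median example below are mine, not from the answer), the struct-based UDF can be wrapped in a small factory so that the quantile value becomes a parameter as well:

def quantile_on_axis1(q: float):
    @F.pandas_udf('double')
    def _udf(s: pd.DataFrame) -> pd.Series:
        # s has one column per struct field, so axis=1 behaves as in plain pandas
        return s.quantile(q, axis=1)
    return _udf

cols = ['col_a', 'col_b', 'col_c']
df = df.withColumn('median', quantile_on_axis1(0.5)(F.struct(*cols)))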

Answered by blackbishop


The UDF approach will get you the result you need and is definitely the most straightforward. However, if performance really is the top priority, you can create your own native Spark implementation of quantile. The basics can be coded quite easily; if you want to use any of the other pandas parameters, you'll need to tweak it yourself.

Note: this is based on the pandas API docs for interpolation='linear'. If you intend to use it, please test the performance and verify the results yourself on large datasets.

import math
from pyspark.sql import functions as f

def quantile(q, cols):
    if q < 0 or q > 1:
        raise ValueError("Parameter q should be 0 <= q <= 1")

    if not cols:
        raise ValueError("List of columns should be provided")

    # Position of the quantile within the row's sorted values, e.g. q=0.25 with
    # three columns gives idx = 0.5 (halfway between elements 0 and 1).
    idx = (len(cols) - 1) * q
    i = math.floor(idx)
    j = math.ceil(idx)
    fraction = idx - i

    # Sort the row's values, then linearly interpolate between the two
    # neighbouring elements (pandas' interpolation='linear' behaviour).
    arr = f.array_sort(f.array(*cols))

    return arr.getItem(i) + (arr.getItem(j) - arr.getItem(i)) * fraction


columns = ['col_a', 'col_b', 'col_c']

df.withColumn('0.25%', quantile(0.25, columns)).show()

+-----+--------+-----+--------+-------+
|value|   col_a|col_b|   col_c|  0.25%|
+-----+--------+-----+--------+-------+
|row_a|     5.0|  0.0|    11.0|    2.5|
|row_b|  3394.0|  0.0|  4543.0| 1697.0|
|row_c|136111.0|  0.0|219255.0|68055.5|
|row_d|     0.0|  0.0|     0.0|    0.0|
|row_e|     0.0|  0.0|     0.0|    0.0|
|row_f|    42.0|  0.0|    54.0|   21.0|
+-----+--------+-----+--------+-------+
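
To see the interpolation at work on row_a: with three columns, idx = (3 - 1) * 0.25 = 0.5, so i = 0, j = 1 and fraction = 0.5; the row's sorted values are [0.0, 5.0, 11.0], giving 0.0 + (5.0 - 0.0) * 0.5 = 2.5, which matches the pandas result above.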

As a side note, there is also the pandas API on Spark; however, axis=1 is not (yet) implemented for quantile. Potentially this will be added in the future.

df.to_pandas_on_spark().quantile(0.25, axis=1)

NotImplementedError: axis should be either 0 or "index" currently.
Answered by ScootCork


