How to use LinearRegression across groups in DataFrame?

Let us say my Spark DataFrame (DF) looks like this:

id | age | earnings | health
-----------------------------
1  | 34  | 65       | 8
2  | 65  | 12       | 4
2  | 20  | 7        | 10
1  | 40  | 75       | 7
.  | ..  | ..       | ..

I would like to group the DF by id, apply a function to each group (say, a linear regression that depends on multiple columns of the grouped DF, two columns in this case), and get output like this:

id | intercept | slope
----------------------
1  |   ?       |  ?
2  |   ?       |  ?

Here is what I have in mind, as pseudocode:

from sklearn.linear_model import LinearRegression
lr_object = LinearRegression()

def linear_regression(ith_DF):
    # Note: for me it is necessary that ith_DF should contain all 
    # data within this function scope, so that I can apply any 
    # function that needs all data in ith_DF

    # scikit-learn expects X as a 2-D array of shape (n_samples, n_features)
    X = [[i.earnings] for i in ith_DF.select("earnings").rdd.collect()]
    y = [i.health for i in ith_DF.select("health").rdd.collect()]

    lr_object.fit(X, y)
    return lr_object.intercept_, lr_object.coef_[0]

coefficient_collector = []

# The following iteration is not possible in Spark because a 'GroupedData'
# object is not iterable; please treat it as pseudocode

for ith_df in df.groupby("id"): 
    c, m = linear_regression(ith_df)
    coefficient_collector.append((float(c), float(m)))

model_df = spark.createDataFrame(coefficient_collector, ["intercept", "slope"])
model_df.show()
Everest asked Oct 15 '25 10:10


1 Answer

I think this can be done since Spark 2.3 using a pandas_udf. In fact, the blog post that announced pandas UDFs includes an example of fitting grouped regressions:

Introducing Pandas UDF for PySpark
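
To make the idea concrete, here is a minimal sketch of the grouped approach; it is not the code from the blog post. It assumes Spark 3.0+ with PyArrow installed, where a grouped-map function can be passed to applyInPandas (on Spark 2.3 the equivalent is a pandas_udf with PandasUDFType.GROUPED_MAP applied via groupby("id").apply). The names fit_group, fit_all_groups and RESULT_SCHEMA are made up for this example, and I use np.polyfit in place of sklearn's LinearRegression to keep the sketch small. Each call receives all rows of one id as a pandas DataFrame, which is the "all data in scope" requirement from the question.

```python
import numpy as np
import pandas as pd

# Hypothetical output layout, matching the table in the question (one row per id).
RESULT_SCHEMA = "id long, intercept double, slope double"


def fit_group(pdf: pd.DataFrame) -> pd.DataFrame:
    """Fit health ~ earnings on the rows of a single id (a plain pandas DataFrame)."""
    # np.polyfit with deg=1 returns [slope, intercept] of the least-squares line.
    slope, intercept = np.polyfit(
        pdf["earnings"].astype(float), pdf["health"].astype(float), deg=1
    )
    return pd.DataFrame(
        {
            "id": [int(pdf["id"].iloc[0])],
            "intercept": [float(intercept)],
            "slope": [float(slope)],
        }
    )


def fit_all_groups(df):
    """Run fit_group once per id on a Spark DataFrame (assumes Spark 3.0+)."""
    return df.groupBy("id").applyInPandas(fit_group, schema=RESULT_SCHEMA)


# Local check with plain pandas, using the four rows shown in the question.
sample = pd.DataFrame(
    {
        "id": [1, 2, 2, 1],
        "age": [34, 65, 20, 40],
        "earnings": [65, 12, 7, 75],
        "health": [8, 4, 10, 7],
    }
)
local_result = pd.concat(
    [fit_group(group) for _, group in sample.groupby("id")], ignore_index=True
)
print(local_result)
```

On a cluster you would call fit_all_groups(df).show() on the Spark DataFrame from the question. Because fit_group only touches pandas and NumPy, it can be checked locally before Spark is involved, as the last few lines do.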

James Kirkby answered Oct 18 '25 00:10


