
How to create a map column with rolling window aggregates per key

Problem description

I need help with a pyspark.sql expression that creates a new column by aggregating records over a specified Window() into a map of key-value pairs.

Reproducible Data

df = spark.createDataFrame(
    [
        ('AK', "2022-05-02", 1651449600, 'US', 3), 
        ('AK', "2022-05-03", 1651536000, 'ON', 1),
        ('AK', "2022-05-04", 1651622400, 'CO', 1),
        ('AK', "2022-05-06", 1651795200, 'AK', 1),
        ('AK', "2022-05-06", 1651795200, 'US', 5)
    ],
    ["state", "ds", "ds_num", "region", "count"]
)

df.show()
# +-----+----------+----------+------+-----+
# |state|        ds|    ds_num|region|count|
# +-----+----------+----------+------+-----+
# |   AK|2022-05-02|1651449600|    US|    3|
# |   AK|2022-05-03|1651536000|    ON|    1|
# |   AK|2022-05-04|1651622400|    CO|    1|
# |   AK|2022-05-06|1651795200|    AK|    1|
# |   AK|2022-05-06|1651795200|    US|    5|
# +-----+----------+----------+------+-----+
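For context, ds_num is just the Unix timestamp (in seconds) of ds at UTC midnight, which is what makes a window frame expressed in seconds possible later on. A minimal check, assuming a UTC session time zone:

import pyspark.sql.functions as F

# 1651449600 == 2022-05-02 00:00:00 UTC; with spark.sql.session.timeZone=UTC
# the ds_num column can be derived from the date string directly
df.withColumn('ds_num_check', F.unix_timestamp('ds', 'yyyy-MM-dd')).show()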

Partial solutions

Sets of regions over a window frame:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

days = lambda i: i * 86400  # convert a number of days to seconds, to match ds_num

df.withColumn('regions_4W',
              F.collect_set('region').over(
                  Window.partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)))\
.sort('ds')\
.show()

# +-----+----------+----------+------+-----+----------------+
# |state|        ds|    ds_num|region|count|      regions_4W|
# +-----+----------+----------+------+-----+----------------+
# |   AK|2022-05-02|1651449600|    US|    3|            [US]|
# |   AK|2022-05-03|1651536000|    ON|    1|        [US, ON]|
# |   AK|2022-05-04|1651622400|    CO|    1|    [CO, US, ON]|
# |   AK|2022-05-06|1651795200|    AK|    1|[CO, US, ON, AK]|
# |   AK|2022-05-06|1651795200|    US|    5|[CO, US, ON, AK]|
# +-----+----------+----------+------+-----+----------------+

Map of counts per state and ds:

df\
.groupby('state', 'ds', 'ds_num')\
.agg(F.map_from_entries(F.collect_list(F.struct("region", "count"))).alias("count_rolling_4W"))\
.sort('ds')\
.show()

# +-----+----------+----------+------------------+
# |state|        ds|    ds_num|  count_rolling_4W|
# +-----+----------+----------+------------------+
# |   AK|2022-05-02|1651449600|         {US -> 3}|
# |   AK|2022-05-03|1651536000|         {ON -> 1}|
# |   AK|2022-05-04|1651622400|         {CO -> 1}|
# |   AK|2022-05-06|1651795200|{AK -> 1, US -> 5}|
# +-----+----------+----------+------------------+
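Note that simply moving this aggregation onto the rolling window does not work: collecting the (region, count) structs over the frame yields duplicate keys (US appears twice within the frame of 2022-05-06), and map_from_entries will not sum them. A sketch of the naive attempt, reusing the days helper and imports from above:

# Naive attempt: build the map straight from the windowed struct list
df.withColumn(
    'count_rolling_4W',
    F.map_from_entries(
        F.collect_list(F.struct('region', 'count')).over(
            Window.partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0))))

# Depending on spark.sql.mapKeyDedupPolicy, this either fails on the duplicate
# US key (EXCEPTION, the default) or silently keeps the last value (LAST_WIN);
# it never sums the counts per key, hence the question.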

Desired Output

What I need is a map that aggregates the counts for each key present in the specified window:

+-----+----------+----------+------------------------------------+
|state|        ds|    ds_num|                    count_rolling_4W|
+-----+----------+----------+------------------------------------+
|   AK|2022-05-02|1651449600|                           {US -> 3}|
|   AK|2022-05-03|1651536000|                  {US -> 3, ON -> 1}|
|   AK|2022-05-04|1651622400|         {US -> 3, ON -> 1, CO -> 1}|
|   AK|2022-05-06|1651795200|{US -> 8, ON -> 1, CO -> 1, AK -> 1}|
+-----+----------+----------+------------------------------------+
Asked by Adrian
1 Answer

You can collect both the set of regions and the list of (region, count) structs over the window, then use the higher-order functions transform and aggregate to sum the counts per region:

from pyspark.sql import Window, functions as F

days = lambda i: i * 86400  # days -> seconds, to match ds_num

w = Window.partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)

df1 = (df.withColumn('regions', F.collect_set('region').over(w))  # distinct keys in the frame
       .withColumn('counts', F.collect_list(F.struct('region', 'count')).over(w))
       .withColumn('counts',
                   F.transform(  # one rolling sum per region
                       'regions',
                       lambda x: F.aggregate(
                           F.filter('counts', lambda y: y['region'] == x),
                           F.lit(0),
                           lambda acc, v: acc + v['count']
                       )
                   ))
       .withColumn('count_rolling_4W', F.map_from_arrays('regions', 'counts'))
       .drop('counts', 'regions')
       )

df1.show(truncate=False)

# +-----+----------+----------+------+-----+------------------------------------+
# |state|ds        |ds_num    |region|count|count_rolling_4W                    |
# +-----+----------+----------+------+-----+------------------------------------+
# |AK   |2022-05-02|1651449600|US    |3    |{US -> 3}                           |
# |AK   |2022-05-03|1651536000|ON    |1    |{US -> 3, ON -> 1}                  |
# |AK   |2022-05-04|1651622400|CO    |1    |{CO -> 1, US -> 3, ON -> 1}         |
# |AK   |2022-05-06|1651795200|AK    |1    |{CO -> 1, US -> 8, ON -> 1, AK -> 1}|
# |AK   |2022-05-06|1651795200|US    |5    |{CO -> 1, US -> 8, ON -> 1, AK -> 1}|
# +-----+----------+----------+------+-----+------------------------------------+
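Since the windowed map is repeated on every input row, the per-day shape of the desired output only needs one row kept per (state, ds); a minimal follow-up sketch:

# Rows sharing a ds_num get identical frames (rangeBetween), so their maps match
# and deduplicating on (state, ds) yields the desired per-day output
df1.select('state', 'ds', 'ds_num', 'count_rolling_4W')\
   .dropDuplicates(['state', 'ds'])\
   .sort('ds')\
   .show(truncate=False)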
Answered by blackbishop