 

Create list of IDs until the running total first exceeds a specific count

Tags: python, pyspark

I have a PySpark DataFrame that I want to perform batch processing on. The DF has an ID column and a count column.

| ID  | Count |
| --- | ----- |
| abc | 500   |
| def | 300   |
| ghi | 400   |
| jkl | 200   |
| mno | 1100  |
| pqr | 900   |

I want to create batches (lists of IDs) whose counts add up until the sum first exceeds a threshold, let's say 1000. In the example above, the first batch would be ['abc','def','ghi'] (500+300+400 = 1200).

The second batch would be ['jkl','pqr'] (200+900)=1100.

['mno'] should be a batch by itself.

I have already written the code to strip out 'mno' as a first step. But how can I develop an iterative process to create the rest of the batches? The order of the IDs does not matter. If I can avoid having to process an ordered DF first, I would prefer that.
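As a reference for the expected result, here is a minimal plain-Python sketch of the batching rule described above (not Spark code). It assumes the rule is: any ID whose count alone exceeds the threshold forms its own batch, and the remaining IDs are accumulated in input order until the running sum first exceeds the threshold. The function name `make_batches` is hypothetical, and keeping a trailing batch that never reaches the threshold is an assumption, since the question does not say what should happen to leftovers.

```python
def make_batches(rows, threshold=1000):
    """Greedy batching: close a batch as soon as its total exceeds `threshold`.

    rows: iterable of (id, count) pairs. IDs whose count alone exceeds the
    threshold become single-item batches. (Hypothetical helper for illustration.)
    """
    batches = []
    singles = []
    current, total = [], 0
    for id_, count in rows:
        if count > threshold:
            singles.append([id_])  # big enough to be a batch by itself
            continue
        current.append(id_)
        total += count
        if total > threshold:  # first time the running sum exceeds the threshold
            batches.append(current)
            current, total = [], 0
    if current:  # assumption: leftover IDs that never reached the threshold form a final batch
        batches.append(current)
    return batches + singles


data = [("abc", 500), ("def", 300), ("ghi", 400),
        ("jkl", 200), ("mno", 1100), ("pqr", 900)]
print(make_batches(data))  # [['abc', 'def', 'ghi'], ['jkl', 'pqr'], ['mno']]
```

Because this loop depends on the order in which rows arrive, any distributed version has to fix some order (or move the rows to one place) before applying it.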

asked Sep 01 '25 03:09 by fstr


1 Answer

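There might be an easier way, but numbering groups with something like floor(cumulative sum / 1000) puts the row that sends the total over 1000 into the next group instead of keeping it in the current one, so a resetting cumulative-sum approach is the only way I could think of to do it: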
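To make the floor problem concrete, here is a small plain-Python comparison using the five rows below 1000 from the question, in input order. Numbering groups as `cumulative_sum // 1000` moves `ghi` (the row that pushes the total to 1200) into the next group, whereas a running total that resets after each group keeps it in group 0. The helper name `resetting_groups` is hypothetical; it mirrors the resetting-counter idea (with the same `>=` check) rather than any Spark API.

```python
from itertools import accumulate

counts = [500, 300, 400, 200, 900]  # abc, def, ghi, jkl, pqr

# Floor of the cumulative sum: the group changes at fixed multiples of 1000
floor_groups = [total // 1000 for total in accumulate(counts)]


def resetting_groups(values, threshold=1000):
    """Assign group ids with a running total that resets once it reaches the threshold."""
    groups, running_total, group_id = [], 0, 0
    for v in values:
        running_total += v
        groups.append(group_id)  # the current row stays in the current group
        if running_total >= threshold:
            group_id += 1
            running_total = 0
    return groups


print(floor_groups)              # [0, 0, 1, 1, 2]
print(resetting_groups(counts))  # [0, 0, 0, 1, 1]
```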

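```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql.window import Window
import pandas as pd

spark = SparkSession.builder.getOrCreate()

data = [
    ("abc", 500),
    ("def", 300),
    ("ghi", 400),
    ("jkl", 200),
    ("mno", 1100),
    ("pqr", 900),
]
df = spark.createDataFrame(data, ["ID", "Count"])

THRESHOLD = 1000

# Split off the rows that reach the threshold on their own; each becomes its own batch
df_big_count = df.where(df.Count >= THRESHOLD)
df_small_count = df.where(df.Count < THRESHOLD)

# Assign row numbers so we can use them to number the larger groups correctly later.
# Note: a window without partitionBy moves all rows into a single partition.
df_small_count = df_small_count.withColumn(
    "rn", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)
df_big_count = df_big_count.withColumn(
    "rn", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)

# Floor and similar methods don't include the trailing ID that sends the total over
# the threshold, so use a counter that resets each time a group is closed.
# Caveat: a Series-to-Series pandas UDF is called once per Arrow batch (at most
# spark.sql.execution.arrow.maxRecordsPerBatch rows, 10,000 by default), so the running
# total and group numbers restart in every batch. That is fine for small data like this,
# but larger inputs need a different approach.
@F.pandas_udf(IntegerType())
def assign_groups(counts: pd.Series) -> pd.Series:
    groups = []
    running_total = 0
    group_id = 0
    for c in counts:
        running_total += c
        groups.append(group_id)
        if running_total >= THRESHOLD:
            group_id += 1
            running_total = 0
    return pd.Series(groups, index=counts.index)


df_small_count = df_small_count.withColumn("group", assign_groups(F.col("Count")))

# Make sure the bigger group numbers continue where the smaller-count groups finished.
# rn starts at 1, so rn + max_group gives max_group + 1, max_group + 2, ...
max_group = df_small_count.agg(F.max("group")).collect()[0][0]
if max_group is None:  # no rows below the threshold at all
    max_group = -1
df_big_count = df_big_count.withColumn("group", F.col("rn") + max_group)

out = df_small_count.unionByName(df_big_count)

out.select("ID", "Count", "group").show()
```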


```
+---+-----+-----+
| ID|Count|group|
+---+-----+-----+
|abc|  500|    0|
|def|  300|    0|
|ghi|  400|    0|
|jkl|  200|    1|
|pqr|  900|    1|
|mno| 1100|    2|
+---+-----+-----+
```
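The answer stops at a `group` column, while the question asks for lists of IDs. In PySpark, something like `out.groupBy("group").agg(F.collect_list("ID").alias("batch"))` should do that aggregation on the cluster (the order of IDs inside each collected list is not guaranteed, which the question says is acceptable). Below is a minimal plain-Python sketch of the same step, assuming the `(ID, group)` rows have already been brought to the driver, for example with `out.select("ID", "group").collect()`; the helper name `rows_to_batches` is hypothetical.

```python
from collections import defaultdict

# (ID, group) pairs, copied from the output table above
rows = [("abc", 0), ("def", 0), ("ghi", 0), ("jkl", 1), ("pqr", 1), ("mno", 2)]


def rows_to_batches(rows):
    """Collect IDs by group number and return the batches ordered by group number."""
    by_group = defaultdict(list)
    for id_, group in rows:
        by_group[group].append(id_)
    return [by_group[g] for g in sorted(by_group)]


batches = rows_to_batches(rows)
print(batches)  # [['abc', 'def', 'ghi'], ['jkl', 'pqr'], ['mno']]
```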
answered Sep 02 '25 16:09 by Chris