I have a PySpark DataFrame that I want to perform batch processing on. The DataFrame has an ID and a Count.
| ID  | Count |
| --- | ----- |
| abc | 500   |
| def | 300   |
| ghi | 400   |
| jkl | 200   |
| mno | 1100  |
| pqr | 900   |
I want to create batches (lists of IDs) whose counts add up until the first time the total exceeds a threshold, let's say 1000.
In the above, the first batch would be ['abc', 'def', 'ghi'] (500 + 300 + 400 = 1200). The second batch would be ['jkl', 'pqr'] (200 + 900 = 1100). ['mno'] should be a batch by itself.
I have already developed the code to strip out 'mno' as the first step, but how can I develop an iterative process to create the rest of the batches? The order of the IDs does not matter, and if I can avoid processing an ordered DataFrame first, I would prefer that.
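In plain Python the rule I'm after would look roughly like the sketch below (assuming the oversized IDs like 'mno' have already been split out, and with the 1000 threshold hard-coded just for the example):

def make_batches(rows, threshold=1000):
    # Greedily fill the current batch until the running total first reaches
    # the threshold, then start a new batch; any leftovers form a final batch.
    batches, current, total = [], [], 0
    for id_, count in rows:
        current.append(id_)
        total += count
        if total >= threshold:
            batches.append(current)
            current, total = [], 0
    if current:
        batches.append(current)
    return batches

# make_batches([("abc", 500), ("def", 300), ("ghi", 400), ("jkl", 200), ("pqr", 900)])
# returns [['abc', 'def', 'ghi'], ['jkl', 'pqr']]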
There might be an easier way, but bucketing with floor on a cumulative sum ends up not keeping the row that pushes the total over 1000 in the same batch, so a resetting cumulative-sum approach is the only way I could think to do it:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType
import pandas as pd
data = [
    ("abc", 500),
    ("def", 300),
    ("ghi", 400),
    ("jkl", 200),
    ("mno", 1100),
    ("pqr", 900),
]
df = spark.createDataFrame(data, ["ID", "Count"])
# Split off the larger groups
df_big_count = df.where(df.Count >= 1000)
df_small_count = df.where(df.Count < 1000)
# Assign row numbers so we can use them to correctly number the larger groups later
df_small_count = df_small_count.withColumn(
    "rn", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)
df_big_count = df_big_count.withColumn(
    "rn", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)
# Floor and other methods don't include the trailing ID that sends it over the threshold
# You would need some sort of resetting counter
@F.pandas_udf(IntegerType())
def assign_groups(counts: pd.Series) -> pd.Series:
    # Resetting cumulative sum: once the running total reaches 1000,
    # start a new group and reset the counter
    groups = []
    running_total = 0
    group_id = 0
    for c in counts:
        running_total += c
        groups.append(group_id)
        if running_total >= 1000:
            group_id += 1
            running_total = 0
    return pd.Series(groups)
df_small_count = df_small_count.withColumn("group", assign_groups(F.col("Count")))
# Make sure the bigger group numbers start counting where the smaller count df finished
df_big_count = df_big_count.withColumn(
    "group",
    F.col("rn") + df_small_count.agg({"group": "max"}).collect()[0][0],
)
out = df_small_count.unionByName(df_big_count, allowMissingColumns=True)
out.select("ID", "Count", "group").show()
+---+-----+-----+
| ID|Count|group|
+---+-----+-----+
|abc| 500| 0|
|def| 300| 0|
|ghi| 400| 0|
|jkl| 200| 1|
|pqr| 900| 1|
|mno| 1100| 2|
+---+-----+-----+
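If you need the batches as actual lists of IDs rather than a group label, collect_list over the group column gets you there (the sum is only included to sanity-check the batch totals):

# One row per batch: the list of IDs and the batch total
batches = (
    out.groupBy("group")
    .agg(F.collect_list("ID").alias("IDs"), F.sum("Count").alias("total"))
    .orderBy("group")
)
batches.show(truncate=False)

One thing to keep in mind: a pandas UDF is fed its input in chunks (per partition / Arrow batch), so the running total resets at those boundaries. That's fine for an example this small, but on bigger data you'd want to coalesce the small-count DataFrame to a single partition first (or do the grouping on the driver) if the counter has to be strictly global.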