The following is my dataframe:
df = spark.createDataFrame([
    (0, 1),
    (0, 2),
    (0, 5),
    (1, 1),
    (1, 2),
    (1, 3),
    (1, 5),
    (2, 1),
    (2, 2)
], ["id", "product"])
I need to group by id and collect all the products as shown below, but I also need to check the product count across the whole DataFrame: if a product occurs fewer than 2 times, it should not appear in the collected items.
For example, product 3 appears only once, i.e. its count is 1, which is less than 2, so it should not appear in the following DataFrame. It looks like I need to do two groupBys:
Expected output:
+---+---------+
| id|    items|
+---+---------+
|  0|[1, 2, 5]|
|  1|[1, 2, 5]|
|  2|   [1, 2]|
+---+---------+
Two groupBys is indeed a decent solution: you can use a leftsemi join after the first groupBy to filter your initial DataFrame on the product counts. Working example:
import pyspark.sql.functions as F
df = spark.createDataFrame([
    (0, 1),
    (0, 2),
    (0, 5),
    (1, 1),
    (1, 2),
    (1, 3),
    (1, 5),
    (2, 1),
    (2, 2)
], ["id", "product"])
df = (df
    # keep only rows whose product occurs at least twice across the whole DataFrame
    .join(df.groupBy('product').count().filter(F.col('count') >= 2),
          'product', 'leftsemi')
    .distinct()
    .orderBy(F.col('id'), F.col('product'))
    .groupBy('id')
    .agg(F.collect_list('product').alias('items')))
df.show()
The orderBy clause is optional; include it only if you care about the ordering of the items inside the collected lists. Output:
+---+---------+
| id|    items|
+---+---------+
|  0|[1, 2, 5]|
|  1|[1, 2, 5]|
|  2|   [1, 2]|
+---+---------+
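As an alternative, if you would rather avoid the extra join, a window function can compute the per-product count in a single pass. This is just a sketch assuming df is the original DataFrame from the question:

from pyspark.sql import Window
import pyspark.sql.functions as F

# Count how often each product occurs across the whole DataFrame,
# keep rows whose product occurs at least twice, then collect per id.
w = Window.partitionBy('product')
result = (df
    .withColumn('product_count', F.count('*').over(w))
    .filter(F.col('product_count') >= 2)
    .groupBy('id')
    .agg(F.sort_array(F.collect_list('product')).alias('items')))
result.show()

Here sort_array sorts each collected list, so you do not have to rely on an orderBy before the groupBy for the ordering inside the lists.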
Hope this helps!