 

How to create new column based on values in array column in Pyspark

I have the following dataframe with codes which represent products:

testdata = [(0, ['a','b','d']), (1, ['c']), (2, ['d','e'])]
df = spark.createDataFrame(testdata, ['id', 'codes'])
df.show()
+---+---------+
| id|    codes|
+---+---------+
|  0|[a, b, d]|
|  1|      [c]|
|  2|   [d, e]|
+---+---------+

Let's say codes a and b represent t-shirts and code c represents sweaters.

tshirts = ['a','b']
sweaters = ['c']

How can I create a column label that checks whether these codes are in the array column and returns the name of the product? Like so:

+---+---------+--------+
| id|    codes|   label|
+---+---------+--------+
|  0|[a, b, d]| tshirts|
|  1|      [c]|sweaters|
|  2|   [d, e]|    none|
+---+---------+--------+

I have already tried a lot of things, amongst others the following, which does not work:

from pyspark.sql import functions as F
from pyspark.sql.types import StringType

codes = {
    'tshirts': ['a','b'],
    'sweaters': ['c']
}

def any_isin(ref_values, array_to_search):
    for key, values in ref_values.items():
        if any(item in array_to_search for item in values):
            return key
        else:
            return 'none'

any_isin_udf = lambda ref_values: F.udf(lambda array_to_search: any_isin(ref_values, array_to_search), StringType())

df_labeled = df.withColumn('label', any_isin_udf(codes)(F.col('codes')))

df_labeled.show()
+---+---------+-------+
| id|    codes|  label|
+---+---------+-------+
|  0|[a, b, d]|tshirts|
|  1|      [c]|   none|
|  2|   [d, e]|   none|
+---+---------+-------+
asked Nov 08 '25 by Cheryl


1 Answer

I would use an expression with array_contains. Let's define the input as a dict:

from pyspark.sql.functions import expr, lit, when
from operator import and_
from functools import reduce

label_map = {"tshirts": ["a", "b"], "sweaters": ["c"]}

Next, generate the expressions:

expression_map = {
    label: reduce(
        and_,
        [expr("array_contains(codes, '{}')".format(code)) for code in codes]
    )
    for label, codes in label_map.items()
}
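
As a side note (my addition, not part of the original answer), the same boolean columns could be built with the DataFrame API function array_contains instead of SQL expression strings; a minimal equivalent sketch:

from pyspark.sql import functions as F

# Equivalent to expression_map above, built with the DataFrame API instead of expr().
# Each value is a boolean Column that is true only when every code listed for that
# label appears in the 'codes' array (the per-code checks are combined with and_).
expression_map_api = {
    label: reduce(and_, [F.array_contains(F.col("codes"), code) for code in codes])
    for label, codes in label_map.items()
}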

Finally, reduce it into a CASE ... WHEN expression:

label = reduce(
    lambda acc, kv: when(kv[1], lit(kv[0])).otherwise(acc),
    expression_map.items(), 
    lit(None).cast("string")
).alias("label")
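
For intuition: the reduce nests the when/otherwise calls so that the entry processed last ends up checked first. For this particular label_map it builds roughly the following hand-written equivalent (a sketch for illustration only, not generated by the code above):

# Roughly the CASE expression the fold above produces for this label_map
# (hand-written for illustration; checks nest in reverse of dict iteration order)
label_manual = expr("""
    CASE
        WHEN array_contains(codes, 'c') THEN 'sweaters'
        WHEN array_contains(codes, 'a') AND array_contains(codes, 'b') THEN 'tshirts'
    END
""").alias("label")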

Result:

df.withColumn("label", label).show()
# +---+---------+--------+                                                        
# | id|    codes|   label|
# +---+---------+--------+
# |  0|[a, b, d]| tshirts|
# |  1|      [c]|sweaters|
# |  2|   [d, e]|    null|
# +---+---------+--------+
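
The question's expected output shows the string 'none' rather than null; if that matters, one option (my addition, not part of the original answer) is to fill the remaining nulls afterwards:

# Replace the null fallback with the literal string 'none'
# to match the expected output from the question.
df.withColumn("label", label).fillna("none", subset=["label"]).show()
# +---+---------+--------+
# | id|    codes|   label|
# +---+---------+--------+
# |  0|[a, b, d]| tshirts|
# |  1|      [c]|sweaters|
# |  2|   [d, e]|    none|
# +---+---------+--------+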
answered Nov 10 '25 by Aaron Makubuya


