PySpark: create a new column based on another column with multiple conditions using a list or set

I am trying to create a new column in a PySpark DataFrame. I have data like the following:

+------+
|letter|
+------+
|     A|
|     C|
|     A|
|     Z|
|     E|
+------+

I want to add a new column based on the letter column, according to the following mapping:

+------+-----+
|letter|group|
+------+-----+
|     A|   c1|
|     B|   c1|
|     F|   c2|
|     G|   c2|
|     I|   c3|
+------+-----+

There can be multiple categories, with many individual letter values (around 100, some of which are multi-letter strings).

I have done this with a udf, and it works well:

from pyspark.sql.functions import udf
from pyspark.sql.types import *

c1 = ['A','B','C','D']
c2 = ['E','F','G','H']
c3 = ['I','J','K','L']
...

def l2c(value):
    if value in c1: return 'c1'
    elif value in c2: return 'c2'
    elif value in c3: return 'c3'
    else: return "na"

udf_l2c = udf(l2c, StringType())
data_with_category = data.withColumn("group", udf_l2c("letter"))

Now I am trying to do it without a udf, maybe using when and col. What I have tried is below. It works, but the code is very long:

from pyspark.sql.functions import when, col

data_with_category = data.withColumn('group', when(col('letter') == 'A', 'c1')
    .when(col('letter') == 'B', 'c1')
    .when(col('letter') == 'F', 'c2')
    ...

It is very long, and it is not practical to write a new when condition for every possible value of letter; the number of letters can be very large (around 100) in my case. So I tried:

data_with_category = data.withColumn('group', when(col('letter') in ['A','B','C','D'] ,'c1')
    .when(col('letter') in ['E','F','G','H'], 'c2')
    .when(col('letter') in ['I','J','K','L'], 'c3')

But it returns an error. How can I solve this?

asked Aug 31 '25 by Prabhu

2 Answers

Use isin.

from pyspark.sql import functions as F

c1 = ['A','B','C','D']
c2 = ['E','F','G','H']
c3 = ['I','J','K','L']

df.withColumn("group", F.when(F.col("letter").isin(c1), F.lit('c1'))
                        .when(F.col("letter").isin(c2), F.lit('c2'))
                        .when(F.col("letter").isin(c3), F.lit('c3'))).show()

#+------+-----+
#|letter|group|
#+------+-----+
#|     A|   c1|
#|     B|   c1|
#|     F|   c2|
#|     G|   c2|
#|     I|   c3|
#+------+-----+
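
With this chain, letters that are not in any of the lists get a null group. If you want the same 'na' fallback as the udf in the question, end the chain with .otherwise(F.lit('na')). And with many categories you do not have to write each .when by hand; here is a minimal sketch that builds the chain from a dict (the letter_groups name is just for illustration):

from pyspark.sql import functions as F

# hypothetical mapping of group name -> list of letters; extend as needed
letter_groups = {
    'c1': ['A', 'B', 'C', 'D'],
    'c2': ['E', 'F', 'G', 'H'],
    'c3': ['I', 'J', 'K', 'L'],
}

# start from a seed condition that never matches, then append one .when per category
group_col = F.when(F.lit(False), None)
for name, letters in letter_groups.items():
    group_col = group_col.when(F.col("letter").isin(letters), name)
group_col = group_col.otherwise("na")  # same default as the question's udf

df.withColumn("group", group_col).show()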
answered Sep 02 '25 by murtihash

You can try using a udf, for example:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

say_hello_udf = udf(lambda name: f"Hello {name}", StringType())
df = spark.createDataFrame([("Rick",), ("Morty",)], ["name"])
df.withColumn("greetings", say_hello_udf(col("name"))).show()

or

@udf(returnType=StringType())
def say_hello(name):
    return f"Hello {name}"

df.withColumn("greetings", say_hello(col("name"))).show()
answered Sep 02 '25 by Phạm Ngọc Quý