 

How can I create multiple columns from one condition using withColumns in Pyspark?

I'd like to create multiple columns in a PySpark DataFrame from a single condition (I'll add more conditions later).

I tried this but it doesn't work:

df.withColumns(F.when(F.col('age') < 6, {'new_c1': F.least(F.col('c1'), F.col('c2')), 
                                  'new_c2': F.least(F.col('c1'), F.col('c3')),
                                  'new_c3': F.least(F.col('c1'), F.col('c4'))}))

In other words: when age < 6, create three new columns, each holding the minimum of c1 and one of the other columns (c2, c3, c4).
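To pin down the intended result, here is the same per-row logic written as plain Python over dictionaries, with no Spark involved. The function name `derive` and the two sample rows are illustrative only; `None` stands in for the SQL null a row gets when the condition does not hold.

```python
def derive(row):
    """Plain-Python sketch of the intended per-row logic (illustration only, no Spark)."""
    if row['age'] < 6:
        return {
            'new_c1': min(row['c1'], row['c2']),
            'new_c2': min(row['c1'], row['c3']),
            'new_c3': min(row['c1'], row['c4']),
        }
    # No value is defined when the condition fails; None plays the role of null.
    return {'new_c1': None, 'new_c2': None, 'new_c3': None}


rows = [
    {'age': 3, 'c1': 1, 'c2': 2, 'c3': 3, 'c4': 4},
    {'age': 10, 'c1': 9, 'c2': 8, 'c3': 7, 'c4': 6},
]
for row in rows:
    print(row['age'], derive(row))
```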

Does withColumns accept when() and otherwise() the way withColumn does? The documentation doesn't say.

I suppose I could split this into separate statements, but I hoped to do it in one shot. Do I need a UDF?

Chuck asked Sep 03 '25 16:09



1 Answer

As stated in the documentation, withColumns (available since Spark 3.3.0) takes as input "a dict of column name and Column. Currently, only single map is supported". In your code, the dictionary is passed inside when(), but when() builds a single Column expression and cannot produce a dictionary, so withColumns never receives the dict it expects. To avoid repeating the condition three times and keep the code generic, wrap each value of your dictionary in the condition instead. Because there is no otherwise(), rows that fail the condition get null:

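from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(3, 1, 2, 3, 4), (10, 9, 8, 7, 6)], ['age', 'c1', 'c2', 'c3', 'c4'])
d = {'new_c1': F.least(F.col('c1'), F.col('c2')),
     'new_c2': F.least(F.col('c1'), F.col('c3')),
     'new_c3': F.least(F.col('c1'), F.col('c4'))}
# Wrap every expression in the same condition; rows that fail it get null
df.withColumns({k: F.when(F.col('age') < 6, v) for k, v in d.items()}).show()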
+---+---+---+---+---+------+------+------+
|age| c1| c2| c3| c4|new_c1|new_c2|new_c3|
+---+---+---+---+---+------+------+------+
|  3|  1|  2|  3|  4|     1|     1|     1|
| 10|  9|  8|  7|  6|  null|  null|  null|
+---+---+---+---+---+------+------+------+

Alternatively, start from the list of columns that c1 is compared with:

cols = ['c2', 'c3', 'c4']
# One (name, expression) pair per compared column: new_c2, new_c3, new_c4
col_list = [('new_' + c, F.when(F.col('age') < 6, F.least(F.col('c1'), F.col(c)))) for c in cols]
df.withColumns(dict(col_list)).show()
Oli answered Sep 05 '25 09:09
