 

pyspark when/otherwise clause failure when using udf

I have a UDF that takes a name as the key and returns the corresponding value from name_dict.

from pyspark.sql import *
from pyspark.sql.functions import udf, when, col

name_dict = {'James': "manager", 'Robert': 'director'}
func = udf(lambda name: name_dict[name])

The original dataframe (James and Robert are keys in the dict, but Michael is not):

data = [("James","M"),("Michael","M"),("Robert",None)]
test = spark.createDataFrame(data = data, schema = ['name', 'gender'])
test.show()
+-------+------+
|   name|gender|
+-------+------+
|  James|     M|
|Michael|     M|
| Robert|  null|
+-------+------+

To prevent a KeyError, I use a when condition to filter the rows before applying the UDF, but it does not work.

test.withColumn('senior', when(col('name').isin(['James', 'Robert']), func(col('name'))).otherwise(col('gender'))).show()

PythonException: An exception was thrown from a UDF: 'KeyError: 'Michael''. Full traceback below...

What is the cause of this, and is there a feasible way to solve it? Assume that not all the names are keys of the dictionary; for those that are not, I would like to copy the value from another column, say gender here.
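For reference, the expected output for the data above would be (Michael, who is not in the dict, falls back to his gender value):

+-------+------+--------+
|   name|gender|  senior|
+-------+------+--------+
|  James|     M| manager|
|Michael|     M|       M|
| Robert|  null|director|
+-------+------+--------+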

Haley asked Oct 14 '25

1 Answer

This is actually the documented behavior of user-defined functions in Spark. From the docs:

The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions.

So in your case you need to rewrite your UDF as:

# dict.get returns the default "NA" instead of raising a KeyError
func = udf(lambda name: name_dict.get(name, "NA"))

Then call it with:

test.withColumn('senior', func(col('name'))).show()

#+-------+------+--------+
#|   name|gender|  senior|
#+-------+------+--------+
#|  James|     M| manager|
#|Michael|     M|      NA|
#| Robert|  null|director|
#+-------+------+--------+
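If you want the fallback to come from the gender column rather than a fixed "NA" string, as asked, one option (a minimal sketch along the same lines) is to pass both columns to the UDF and use the second one as the default:

func = udf(lambda name, gender: name_dict.get(name, gender))
test.withColumn('senior', func(col('name'), col('gender'))).show()

#+-------+------+--------+
#|   name|gender|  senior|
#+-------+------+--------+
#|  James|     M| manager|
#|Michael|     M|       M|
#| Robert|  null|director|
#+-------+------+--------+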

However, in your case you can actually do this without using a UDF at all, by using a map column:

from itertools import chain
from pyspark.sql.functions import col, create_map, lit

# build a literal map column: map('James' -> 'manager', 'Robert' -> 'director')
map_col = create_map(*[lit(x) for x in chain(*name_dict.items())])
test.withColumn('senior', map_col[col('name')]).show()
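Looking up a key that is absent from the map yields null, so Michael would get null for senior here. To get the gender fallback you asked for, you can wrap the lookup in coalesce (a small sketch building on the snippet above):

from pyspark.sql.functions import coalesce

test.withColumn('senior', coalesce(map_col[col('name')], col('gender'))).show()

#+-------+------+--------+
#|   name|gender|  senior|
#+-------+------+--------+
#|  James|     M| manager|
#|Michael|     M|       M|
#| Robert|  null|director|
#+-------+------+--------+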
blackbishop answered Oct 16 '25