I have a UDF which takes a key and returns the corresponding value from name_dict.
from pyspark.sql import *
from pyspark.sql.functions import udf, when, col
name_dict = {'James': "manager", 'Robert': 'director'}
func = udf(lambda name: name_dict[name])
In the original DataFrame, James and Robert are in the dict, but Michael is not:
data = [("James","M"),("Michael","M"),("Robert",None)]
test = spark.createDataFrame(data = data, schema = ['name', 'gender'])
test.show()
+-------+------+
|   name|gender|
+-------+------+
|  James|     M|
|Michael|     M|
| Robert|  null|
+-------+------+
To prevent the KeyError, I use a when condition to filter the rows before applying the UDF, but it does not work:
test.withColumn('senior', when(col('name').isin(['James', 'Robert']), func(col('name'))).otherwise(col('gender'))).show()
PythonException: An exception was thrown from a UDF: 'KeyError: 'Michael'', from , line 8. Full traceback below...
What is the cause of this, and is there a feasible way to solve it? Assume that not all names are keys of the dictionary; for those that are not, I would like to copy the value from another column, say gender here.
This is actually the documented behavior of user-defined functions in Spark. From the docs:
The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions.
So in your case, you need to move the check into the UDF itself, for example by using dict.get with a default:
func = udf(lambda name: name_dict.get(name, "NA"))
Then call it directly, with no when needed:
test.withColumn('senior', func(col('name'))).show()
#+-------+------+--------+
#|   name|gender|  senior|
#+-------+------+--------+
#|  James|     M| manager|
#|Michael|     M|      NA|
#| Robert|  null|director|
#+-------+------+--------+
However, in your case you can do this without a UDF at all, by using a map column:
from itertools import chain
from pyspark.sql.functions import col, create_map, lit
map_col = create_map(*[lit(x) for x in chain(*name_dict.items())])
test.withColumn('senior', map_col[col('name')]).show()