I have a DataFrame in the format shown below:
movieId1 | genreList1            | genreList2
---------------------------------------------
1        | [Adventure,Comedy]    | [Adventure]
2        | [Animation,Drama,War] | [War,Drama]
3        | [Adventure,Drama]     | [Drama,War]
I am trying to create a flag column that shows whether genreList2 is a subset of genreList1:
movieId1 | genreList1            | genreList2   | Flag
-----------------------------------------------------
1        | [Adventure,Comedy]    | [Adventure]  | 1
2        | [Animation,Drama,War] | [War,Drama]  | 1
3        | [Adventure,Drama]     | [Drama,War]  | 0
I have tried this:
def intersect_check(a: Array[String], b: Array[String]): Int = {
  if (b.sameElements(a.intersect(b))) { return 1 }
  else { return 2 }
}
def intersect_check_udf =
  udf((colvalue1: Array[String], colvalue2: Array[String]) => intersect_check(colvalue1, colvalue2))
data = data.withColumn("Flag", intersect_check_udf(col("genreList1"), col("genreList2")))
But this throws an error:
org.apache.spark.SparkException: Failed to execute user defined function.
P.S.: The above function (intersect_check) works for Arrays.
We can define a UDF that calculates the length of the intersection between the two array columns and checks whether it equals the length of the second column. If so, the second array is a subset of the first one.
Also, the inputs of your UDF need to be of type WrappedArray[String] (or, more generally, Seq[String]), not Array[String], because Spark passes array columns to a Scala UDF as a WrappedArray:
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.{col, udf}
// Returns 1 if every element of b is matched in a, otherwise 0
val same_elements = udf { (a: WrappedArray[String], b: WrappedArray[String]) =>
  if (a.intersect(b).length == b.length) 1 else 0
}
df.withColumn("test", same_elements(col("genreList1"), col("genreList2")))
  .show(truncate = false)
+--------+-----------------------+------------+----+
|movieId1|genreList1             |genreList2  |test|
+--------+-----------------------+------------+----+
|1       |[Adventure, Comedy]    |[Adventure] |1   |
|2       |[Animation, Drama, War]|[War, Drama]|1   |
|3       |[Adventure, Drama]     |[Drama, War]|0   |
+--------+-----------------------+------------+----+
Data
val df = List(
  (1, Array("Adventure", "Comedy"), Array("Adventure")),
  (2, Array("Animation", "Drama", "War"), Array("War", "Drama")),
  (3, Array("Adventure", "Drama"), Array("Drama", "War"))
).toDF("movieId1", "genreList1", "genreList2")
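A side note on the function from the question: even with the parameter types changed to Seq[String], the sameElements comparison is order-sensitive, so it would likely not flag row 2 as a subset. The sketch below is plain Scala (no Spark needed) and only illustrates that point. The object and function names are made up for illustration, and it returns 0 rather than 2 in the negative case to match the desired Flag column.

```scala
object IntersectCheckDemo {
  // The question's logic with Seq parameters.
  // a.intersect(b) keeps the order of a, so this compares b against
  // the intersection position by position.
  def orderSensitive(a: Seq[String], b: Seq[String]): Int =
    if (b.sameElements(a.intersect(b))) 1 else 0

  // Hypothetical order-insensitive variant: compare the two as sets.
  def orderInsensitive(a: Seq[String], b: Seq[String]): Int =
    if (a.intersect(b).toSet == b.toSet) 1 else 0
}
```

For row 2, Seq("Animation", "Drama", "War").intersect(Seq("War", "Drama")) gives the elements in the order Drama, War, so orderSensitive returns 0 while orderInsensitive returns 1.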
Here is a solution that converts both arrays to sets and uses subsetOf:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
val spark =
  SparkSession.builder().master("local").appName("test").getOrCreate()
import spark.implicits._
val data = spark.sparkContext.parallelize(
  Seq(
    (1, Array("Adventure", "Comedy"), Array("Adventure")),
    (2, Array("Animation", "Drama", "War"), Array("War", "Drama")),
    (3, Array("Adventure", "Drama"), Array("Drama", "War"))
  )).toDF("movieId1", "genreList1", "genreList2")
// Returns 1 if every genre in col2 also appears in col1, otherwise 0
val subsetOf = udf((col1: Seq[String], col2: Seq[String]) => {
  if (col2.toSet.subsetOf(col1.toSet)) 1 else 0
})
data.withColumn("flag", subsetOf(data("genreList1"), data("genreList2"))).show()
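The length-of-intersection check and the set-based check agree on the sample data, but they can differ when genreList2 contains duplicates, because intersect on a Seq is a multiset intersection. Here is a minimal plain-Scala sketch of the difference (no Spark required; the object and function names are hypothetical):

```scala
object SubsetVariants {
  // Length-of-intersection check: a duplicate in b must also be
  // duplicated in a to count as matched.
  def byIntersectLength(a: Seq[String], b: Seq[String]): Int =
    if (a.intersect(b).length == b.length) 1 else 0

  // Set-based check: duplicates and order are ignored.
  def bySubsetOf(a: Seq[String], b: Seq[String]): Int =
    if (b.toSet.subsetOf(a.toSet)) 1 else 0
}
```

With a = Seq("Drama", "War") and b = Seq("Drama", "Drama"), byIntersectLength returns 0 and bySubsetOf returns 1. Which one is right depends on whether repeated genres should matter for your data.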
Hope this helps!