Generic UDAF in Spark 3.0 using Aggregator

Spark 3.0 has deprecated UserDefinedAggregateFunction, and I am trying to rewrite my UDAF using Aggregator. Basic usage of Aggregator is simple; however, I struggle with a more generic version of the function.

I will try to explain my problem with this example, an implementation of collect_set. It's not my actual use case, but it makes the problem easier to explain:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoders, Row}

class CollectSetDemoAgg(name: String) extends Aggregator[Row, Set[Int], Set[Int]] {
  override def zero = Set.empty[Int]
  override def reduce(b: Set[Int], a: Row) = b + a.getInt(a.fieldIndex(name))
  override def merge(b1: Set[Int], b2: Set[Int]) = b1 ++ b2
  override def finish(reduction: Set[Int]) = reduction
  override def bufferEncoder = Encoders.kryo[Set[Int]]
  override def outputEncoder = ExpressionEncoder() // reflection-derived encoder for Set[Int]
}

// using it:
df.agg(new CollectSetDemoAgg("rank").toColumn as "result").show()

I prefer .toColumn over .udf.register, but that's not the point here.
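For completeness, the registration route looks roughly like the sketch below. It's an illustration, not code from my project: it assumes a hypothetical typed variant of the aggregator (input Int instead of Row), plus an existing spark session and a df with a "rank" column:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{Encoder, Encoders}

// Hypothetical typed variant: the input is the column value itself, not the Row.
object CollectSetIntAgg extends Aggregator[Int, Set[Int], Seq[Int]] {
  def zero: Set[Int] = Set.empty
  def reduce(b: Set[Int], a: Int): Set[Int] = b + a
  def merge(b1: Set[Int], b2: Set[Int]): Set[Int] = b1 ++ b2
  def finish(r: Set[Int]): Seq[Int] = r.toSeq
  def bufferEncoder: Encoder[Set[Int]] = Encoders.kryo[Set[Int]]
  def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
}

// Register once, then call it from SQL expressions.
spark.udf.register("collect_set_demo", udaf(CollectSetIntAgg))
df.selectExpr("collect_set_demo(rank) AS result").show()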

Problem: I cannot make a universal version of this Aggregator; it only works with integers.

I've attempted:

class CollectSetDemo(name: String) extends Aggregator[Row, Set[Any], Set[Any]] 

It crashes with this error:

java.lang.UnsupportedOperationException: No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
    at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)

I could not go with CollectSetDemo[T] either, since I was not able to build a proper outputEncoder for it. Also, when using udaf, I can only work with Spark data types, columns, etc.
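For the record, a compile-time type parameter can be made to work if the aggregator demands a TypeTag for it, as in the sketch below (an illustration, not code from my project). The catch is that T then has to be spelled out at the call site, while I need the element type to be decided at runtime:

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, Row}

// Sketch: generic over T, but T must be fixed at compile time.
class CollectSetTypedAgg[T: TypeTag](name: String)
    extends Aggregator[Row, Set[T], Seq[T]] {
  override def zero: Set[T] = Set.empty
  override def reduce(b: Set[T], a: Row): Set[T] = b + a.getAs[T](name)
  override def merge(b1: Set[T], b2: Set[T]): Set[T] = b1 ++ b2
  override def finish(reduction: Set[T]): Seq[T] = reduction.toSeq
  override def bufferEncoder: Encoder[Set[T]] = Encoders.kryo[Set[T]]
  // TypeTag[T] in scope lets the compiler materialize TypeTag[Seq[T]] here
  override def outputEncoder: Encoder[Seq[T]] = ExpressionEncoder()
}

// usage: the element type must be written out at the call site
// df.agg(new CollectSetTypedAgg[Int]("rank").toColumn as "result").show()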

1 Answer

I have not found a nice way to solve the situation, but I was able to somewhat work around it. The code was partially borrowed from RowEncoder:

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.typeTag

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row}

class CollectSetDemoAgg(name: String, fieldType: DataType) extends Aggregator[Row, Set[Any], Any] {
  override def zero = Set.empty[Any]
  override def reduce(b: Set[Any], a: Row) = b + a.get(a.fieldIndex(name))
  override def merge(b1: Set[Any], b2: Set[Any]) = b1 ++ b2
  override def finish(reduction: Set[Any]) = reduction.toSeq
  override def bufferEncoder = Encoders.kryo[Set[Any]]

  // map the runtime DataType to a TypeTag, then assemble the encoder by hand
  override def outputEncoder = {
    val mirror = ScalaReflection.mirror
    val tt = fieldType match {
      case ArrayType(LongType, _) => typeTag[Seq[Long]]
      case ArrayType(IntegerType, _) => typeTag[Seq[Int]]
      case ArrayType(StringType, _) => typeTag[Seq[String]]
      // .. etc etc
      case _ => throw new RuntimeException(s"Could not create encoder for ${name} column (${fieldType})")
    }
    val tpe = tt.in(mirror).tpe

    val cls = mirror.runtimeClass(tpe)
    val serializer = ScalaReflection.serializerForType(tpe)
    val deserializer = ScalaReflection.deserializerForType(tpe)

    new ExpressionEncoder[Any](serializer, deserializer, ClassTag[Any](cls))
  }
}

One thing I had to add was a result data type parameter in the aggregator. The usage then changed to:

df.agg(new CollectSetDemoAgg("rank", new ArrayType(IntegerType, true)).toColumn as "result").show()
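A small refinement, if anyone wants it: the element type can be looked up in the DataFrame's schema instead of being hard-coded (a sketch, assuming df actually has the "rank" column):

import org.apache.spark.sql.types.ArrayType

// read the column's type off the schema at runtime instead of spelling it out
val elementType = df.schema("rank").dataType
df.agg(new CollectSetDemoAgg("rank", ArrayType(elementType, containsNull = true))
  .toColumn as "result").show()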

I really don't like how it turned out, but it works. I also welcome any suggestions on how to improve it.
