Spark UDAF with ArrayType as bufferSchema performance issues

Solution 1:

TL;DR Either don't use UDAF or use primitive types in place of ArrayType.

Without UserDefinedFunction

Both solutions should skip expensive juggling between internal and external representation.

Using standard aggregates and pivot

This uses standard SQL aggregations. While optimized internally, it might become expensive when the number of keys and the size of the array grow.

Given input:

val df = Seq((1, 2, 1), (1, 3, 1), (1, 2, 3)).toDF("id", "index", "value")

You can:

import org.apache.spark.sql.functions.{array, coalesce, col, lit}

val nBuckets = 10
@transient val values = array(
  0 until nBuckets map (c => coalesce(col(c.toString), lit(0))): _*
)

df
  .groupBy("id")
  .pivot("index", 0 until nBuckets)
  .sum("value")
  .select($"id", values.alias("values"))
+---+--------------------+
| id|              values|
+---+--------------------+
|  1|[0, 0, 4, 1, 0, 0...|
+---+--------------------+

Using RDD API with combineByKey / aggregateByKey

Plain old byKey aggregation with a mutable buffer. No bells and whistles, but it should perform reasonably well with a wide range of inputs. If you suspect the input to be sparse, you may consider a more efficient intermediate representation, like a mutable Map.

// assuming rdd is the data keyed by id, e.g. derived from df like this
val rdd = df.rdd.map(r => (r.getInt(0), (r.getInt(1), r.getInt(2))))

rdd
  .aggregateByKey(Array.fill(nBuckets)(0L))(
    { case (acc, (index, value)) => { acc(index) += value; acc }},
    (acc1, acc2) => { for (i <- 0 until nBuckets) acc1(i) += acc2(i); acc1 }
  ).toDF
+---+--------------------+
| _1|                  _2|
+---+--------------------+
|  1|[0, 0, 4, 1, 0, 0...|
+---+--------------------+
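
For the combineByKey variant mentioned in the heading, the same logic splits into createCombiner / mergeValue / mergeCombiners. A sketch only, assuming the same keyed rdd and nBuckets as above:

rdd
  .combineByKey(
    // createCombiner: start a fresh buffer from the first (index, value) pair
    (x: (Int, Int)) => {
      val acc = Array.fill(nBuckets)(0L)
      acc(x._1) += x._2
      acc
    },
    // mergeValue: fold another pair into an existing buffer
    (acc: Array[Long], x: (Int, Int)) => { acc(x._1) += x._2; acc },
    // mergeCombiners: element-wise sum of two buffers
    (acc1: Array[Long], acc2: Array[Long]) => {
      for (i <- 0 until nBuckets) acc1(i) += acc2(i)
      acc1
    }
  ).toDF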

Using UserDefinedFunction with primitive types

As far as I understand the internals, the performance bottleneck is ArrayConverter.toCatalystImpl.

It looks like it is called for each call to MutableAggregationBuffer.update, and in turn allocates a new GenericArrayData for each Row.
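
For comparison, a bufferSchema with a single ArrayType field, roughly like the one in the question (the field name below is just an assumption), hits that conversion path on every update:

def bufferSchema: StructType = StructType(
  // one ArrayType column: every update writes the whole external Seq back,
  // converting it to the internal representation and allocating a new GenericArrayData
  StructField("buckets", ArrayType(LongType)) :: Nil
)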

If we redefine bufferSchema as:

def bufferSchema: StructType = {
  StructType(
    0 until nBuckets map (i => StructField(s"x$i", LongType))
  )
}

both update and merge can be expressed as plain replacements of primitive values in the buffer. The call chain will remain pretty long, but it won't require copies / conversions or excessive allocations. Omitting null checks, you'll need something similar to

// buffer positions are Int, so the Long index has to be narrowed
val index = input.getLong(0).toInt
buffer.update(index, buffer.getLong(index) + input.getLong(1))

and

for (i <- 0 until nBuckets) {
  buffer1.update(i, buffer1.getLong(i) + buffer2.getLong(i))
}

respectively.

Finally, evaluate should take the Row and convert it to the output Seq:

for (i <- 0 until nBuckets) yield buffer.getLong(i)
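
Put together, a complete UserDefinedAggregateFunction along these lines might look like the sketch below. The class name, the input schema and the omitted null handling are assumptions rather than code taken from the question:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class SumArrayAtIndexUDAF(nBuckets: Int) extends UserDefinedAggregateFunction {
  // (index, value) pairs coming in
  def inputSchema: StructType = StructType(
    StructField("index", LongType) :: StructField("value", LongType) :: Nil
  )

  // one primitive LongType field per bucket instead of a single ArrayType column
  def bufferSchema: StructType = StructType(
    0 until nBuckets map (i => StructField(s"x$i", LongType))
  )

  def dataType: DataType = ArrayType(LongType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit =
    for (i <- 0 until nBuckets) buffer.update(i, 0L)

  // null checks omitted, as in the snippets above
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val index = input.getLong(0).toInt
    buffer.update(index, buffer.getLong(index) + input.getLong(1))
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    for (i <- 0 until nBuckets)
      buffer1.update(i, buffer1.getLong(i) + buffer2.getLong(i))

  def evaluate(buffer: Row): Seq[Long] =
    for (i <- 0 until nBuckets) yield buffer.getLong(i)
}

The repartition example further down assumes an instance of such a class (e.g. new SumArrayAtIndexUDAF(nBuckets)) applied to the index and value columns.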

Please note that in this implementation a possible bottleneck is merge. While it shouldn't introduce any new performance problems, with M buckets each call to merge is O(M).

With K unique keys and P partitions it will be called around K * (P - 1) times in the worst-case scenario, where each key occurs at least once on each partition. This effectively increases the complexity of the merge component to O(M * K * P).
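
To put that in perspective with purely hypothetical numbers: M = 10,000 buckets, K = 100,000 keys and P = 200 partitions give on the order of K * P = 2 * 10^7 merge calls and M * K * P = 2 * 10^11 additions spent in merge alone.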

In general there is not much you can do about it. However, if you can make specific assumptions about the data distribution (the data is sparse, the key distribution is uniform), you can shortcut things a bit and shuffle first:

df
  .repartition(n, $"key")
  .groupBy($"key")
  .agg(SumArrayAtIndexUDAF($"index", $"value"))

If the assumptions are satisfied it should:

  • Counterintuitively reduce the shuffle size by shuffling sparse pairs instead of dense array-like Rows.
  • Aggregate data using updates only (each O(1)), possibly touching only a subset of indices.

However, if one or both assumptions are not satisfied, you can expect the shuffle size to increase while the number of updates stays the same. At the same time, data skew can make things even worse than in the update - shuffle - merge scenario.

Using Aggregator with "strongly" typed Dataset

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.{Encoder, Encoders}

class SumArrayAtIndex[I](f: I => (Int, Long))(bucketSize: Int) extends Aggregator[I, Array[Long], Seq[Long]]
    with Serializable {
  def zero = Array.fill(bucketSize)(0L)
  def reduce(acc: Array[Long], x: I) = {
    val (i, v) = f(x)
    acc(i) += v
    acc
  }

  def merge(acc1: Array[Long], acc2: Array[Long]) = {
    for {
      i <- 0 until bucketSize
    } acc1(i) += acc2(i)
    acc1
  }

  def finish(acc: Array[Long]) = acc.toSeq

  // Kryo-serialized buffer; the output Seq uses the default expression encoder
  def bufferEncoder: Encoder[Array[Long]] = Encoders.kryo[Array[Long]]
  def outputEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
}

which could be used as shown below

val ds = Seq((1, (1, 3L)), (1, (2, 5L)), (1, (0, 1L)), (1, (4, 6L))).toDS

ds
  .groupByKey(_._1)
  .agg(new SumArrayAtIndex[(Int, (Int, Long))](_._2)(10).toColumn)
  .show(false)
+-----+-------------------------------+
|value|SumArrayAtIndex(scala.Tuple2)  |
+-----+-------------------------------+
|1    |[1, 3, 5, 0, 6, 0, 0, 0, 0, 0] |
|2    |[0, 11, 0, 0, 0, 0, 0, 0, 0, 0]|
+-----+-------------------------------+

Note:

See also SPARK-27296 - User Defined Aggregating Functions (UDAFs) have a major efficiency problem