DataFrame / Dataset groupBy behaviour/optimization
Suppose we have DataFrame df
consisting of the following columns:
Name, Surname, Size, Width, Length, Weight
Now we want to perform a couple of operations, for example create two DataFrames containing aggregated data about Size and Width:
val df1 = df.groupBy("surname").agg( sum("size") )
val df2 = df.groupBy("surname").agg( sum("width") )
As you can see, other columns, like Length, are not used anywhere. Is Spark smart enough to drop the redundant columns before the shuffle phase, or are they carried around? Will running:
val dfBasic = df.select("surname", "size", "width")
before grouping somehow affect the performance?
Solution 1:
Yes, it is "smart enough". groupBy performed on a DataFrame is not the same operation as groupBy performed on a plain RDD. In the scenario you've described there is no need to move the raw data at all. Let's create a small example to illustrate that:
val df = sc.parallelize(Seq(
("a", "foo", 1), ("a", "foo", 3), ("b", "bar", 5), ("b", "bar", 1)
)).toDF("x", "y", "z")
df.groupBy("x").agg(sum($"z")).explain
// == Physical Plan ==
// *HashAggregate(keys=[x#148], functions=[sum(cast(z#150 as bigint))])
// +- Exchange hashpartitioning(x#148, 200)
// +- *HashAggregate(keys=[x#148], functions=[partial_sum(cast(z#150 as bigint))])
// +- *Project [_1#144 AS x#148, _3#146 AS z#150]
// +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#144, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#145, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#146]
// +- Scan ExternalRDDScan[obj#143]
As you can see, the first phase is a projection where only the required columns are preserved. Next the data is aggregated locally, and finally it is transferred and aggregated globally. You'll get slightly different output if you use Spark <= 1.4, but the general structure should be exactly the same.
Finally, a DAG visualization (available in the Spark UI) confirms that the description above matches the actual job.
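To address the second part of the question: the explicit select shouldn't change anything here, because the optimizer already prunes the unused columns. A minimal way to check this yourself, reusing the df defined above in the same spark-shell session, is to compare the two explain outputs:
// Both queries should produce essentially the same physical plan
// (Project -> partial HashAggregate -> Exchange -> final HashAggregate),
// since column pruning happens with or without the explicit select.
df.groupBy("x").agg(sum($"z")).explain
df.select("x", "z").groupBy("x").agg(sum($"z")).explain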
Similarly, Dataset.groupByKey followed by reduceGroups contains both a map-side (ObjectHashAggregate with partial_reduceaggregator) and a reduce-side (ObjectHashAggregate with reduceaggregator) reduction:
case class Foo(x: String, y: String, z: Int)
val ds = df.as[Foo]
ds.groupByKey(_.x).reduceGroups((x, y) => x.copy(z = x.z + y.z)).explain
// == Physical Plan ==
// ObjectHashAggregate(keys=[value#126], functions=[reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
// +- Exchange hashpartitioning(value#126, 200)
// +- ObjectHashAggregate(keys=[value#126], functions=[partial_reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
// +- AppendColumns <function1>, newInstance(class $line40.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#126]
// +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
// +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
// +- Scan ExternalRDDScan[obj#3]
However, other methods of KeyValueGroupedDataset might work similarly to RDD.groupByKey. For example, mapGroups (or flatMapGroups) doesn't use partial aggregation:
ds.groupByKey(_.x)
.mapGroups((_, iter) => iter.reduce((x, y) => x.copy(z = x.z + y.z)))
.explain
//== Physical Plan ==
//*SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).x, true, false) AS x#37, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).y, true, false) AS y#38, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).z AS z#39]
//+- MapGroups <function2>, value#32.toString, newInstance(class $line15.$read$$iw$$iw$Foo), [value#32], [x#8, y#9, z#10], obj#36: $line15.$read$$iw$$iw$Foo
// +- *Sort [value#32 ASC NULLS FIRST], false, 0
// +- Exchange hashpartitioning(value#32, 200)
// +- AppendColumns <function1>, newInstance(class $line15.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#32]
// +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
// +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
// +- Scan ExternalRDDScan[obj#3]
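If you want to see the difference without reading the whole plan, a quick sketch (reusing the ds defined above) is to check the executed plans for partial aggregation markers:
// The reduceGroups plan contains a partial (map-side) aggregation step,
// while the mapGroups plan does not, so mapGroups shuffles every row of
// each group across the network.
val reducedPlan = ds.groupByKey(_.x)
  .reduceGroups((x, y) => x.copy(z = x.z + y.z))
  .queryExecution.executedPlan.toString
val mappedPlan = ds.groupByKey(_.x)
  .mapGroups((_, iter) => iter.reduce((x, y) => x.copy(z = x.z + y.z)))
  .queryExecution.executedPlan.toString

println(reducedPlan.contains("partial_reduceaggregator")) // expected: true
println(mappedPlan.contains("partial"))                   // expected: false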