Avoid performance impact of a single partition mode in Spark window functions
In practice, the performance impact will be almost the same as if you omitted the partitionBy clause altogether. All records will be shuffled to a single partition, sorted locally, and iterated sequentially one by one.
The only difference is the total number of partitions created. Let's illustrate that with an example using a simple dataset with 10 partitions and 1000 records:
from pyspark.sql import Window, functions as f
df = spark.range(0, 1000, 1, 10).toDF("index").withColumn("col1", f.randn(42))
If you define a frame without a partitionBy clause
w_unpart = Window.orderBy(f.col("index").asc())
and use it with lag
df_lag_unpart = df.withColumn(
"diffs_col1", f.lag("col1", 1).over(w_unpart) - f.col("col1")
)
there will be only one partition in total:
df_lag_unpart.rdd.glom().map(len).collect()
[1000]
In contrast, a frame definition with a dummy index (simplified a bit compared to your code):
w_part = Window.partitionBy(f.lit(0)).orderBy(f.col("index").asc())
will use a number of partitions equal to spark.sql.shuffle.partitions:
spark.conf.set("spark.sql.shuffle.partitions", 11)
df_lag_part = df.withColumn(
"diffs_col1", f.lag("col1", 1).over(w_part) - f.col("col1")
)
df_lag_part.rdd.glom().count()
11
with only one non-empty partition:
df_lag_part.rdd.glom().filter(lambda x: x).count()
1
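The single non-empty partition follows from how hash partitioning treats a constant key. A toy sketch in plain Python, assuming a simplified partitioner that takes hash(key) modulo the partition count (Spark actually uses a Murmur3-based hash, so concrete partition ids differ; only the "same key, same partition" property matters here):

```python
from collections import Counter

def toy_partition_id(key, num_partitions):
    # Assumption: Python's hash() stands in for Spark's Murmur3-based
    # hash partitioning; equal keys always map to the same partition
    return hash(key) % num_partitions

num_partitions = 11  # spark.sql.shuffle.partitions in the example
keys = [0] * 1000    # partitionBy(f.lit(0)): every row carries the same key

counts = Counter(toy_partition_id(k, num_partitions) for k in keys)
# Include empty partitions, similar to rdd.glom().map(len)
partition_sizes = [counts.get(p, 0) for p in range(num_partitions)]

print(len(partition_sizes))                        # → 11
print(sum(1 for size in partition_sizes if size))  # → 1
```

Any other constant key, or any other deterministic hash function, gives the same outcome: one partition receives every row and the rest stay empty.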
Unfortunately, there is no universal solution that can address this problem in PySpark. This is simply an inherent property of the implementation, combined with the distributed processing model.
Since the index column is sequential, you could generate an artificial partitioning key with a fixed number of records per block:
rec_per_block = df.count() // int(spark.conf.get("spark.sql.shuffle.partitions"))
df_with_block = df.withColumn(
"block", (f.col("index") / rec_per_block).cast("int")
)
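For the example data (1000 sequential ids, spark.sql.shuffle.partitions set to 11), the block key can be replayed in plain Python to see how rows are grouped. This is only a sketch of the arithmetic; int() is used to mirror the truncating cast("int"), which behaves the same for the non-negative ids used here:

```python
from collections import Counter

n_records = 1000
shuffle_partitions = 11  # value set via spark.sql.shuffle.partitions above

rec_per_block = n_records // shuffle_partitions  # 1000 // 11 == 90
# (index / rec_per_block).cast("int") truncates the quotient
blocks = [int(index / rec_per_block) for index in range(n_records)]
block_sizes = Counter(blocks)

print(rec_per_block)     # → 90
print(len(block_sizes))  # → 12 (blocks 0..11)
print(block_sizes[11])   # → 10 (the last block is short)
```

This yields 12 block keys for 11 shuffle partitions, so at least two blocks have to share a partition regardless of the hash function, which is one reason the per-partition counts cannot be perfectly uniform.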
Then use it to define the frame specification:
w_with_block = Window.partitionBy("block").orderBy("index")
df_lag_with_block = df_with_block.withColumn(
"diffs_col1", f.lag("col1", 1).over(w_with_block) - f.col("col1")
)
This will use the expected number of partitions:
df_lag_with_block.rdd.glom().count()
11
with a roughly uniform data distribution (we cannot avoid hash collisions):
df_lag_with_block.rdd.glom().map(len).collect()
[0, 180, 0, 90, 90, 0, 90, 90, 100, 90, 270]
but with a number of gaps at the block boundaries:
df_lag_with_block.where(f.col("diffs_col1").isNull()).count()
12
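The 12 missing values are exactly the first row of each block: lag only looks at the previous row inside the same window partition, and the first row of a block has none. A plain-Python sketch of that behaviour, assuming rows are already sorted by index and using random.gauss as a stand-in for f.randn(42) (so the values differ from Spark's, but the null positions do not); lag_diff_within_blocks is a hypothetical helper:

```python
import random

def lag_diff_within_blocks(values, blocks):
    """Rough model of f.lag(col, 1).over(partitionBy(block).orderBy(index)) - col.

    Assumes values are sorted by index; returns None where the block has
    no previous row, mirroring the null produced by lag.
    """
    last_in_block = {}
    diffs = []
    for value, block in zip(values, blocks):
        prev = last_in_block.get(block)
        diffs.append(None if prev is None else prev - value)
        last_in_block[block] = value
    return diffs

random.seed(42)
index = list(range(1000))
col1 = [random.gauss(0.0, 1.0) for _ in index]  # stand-in for f.randn(42)
block = [i // 90 for i in index]                # rec_per_block == 90

diffs = lag_diff_within_blocks(col1, block)
null_positions = [i for i, d in enumerate(diffs) if d is None]
print(len(null_positions))  # → 12
print(null_positions[:3])   # → [0, 90, 180]
```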
Since the boundaries are easy to compute:
from itertools import chain
boundary_idxs = sorted(chain.from_iterable(
# Here we depend on sequential identifiers
# This could be generalized to any monotonically increasing
# id by taking min and max per block
(idx - 1, idx) for idx in
df_lag_with_block.groupBy("block").min("index")
.drop("block").rdd.flatMap(lambda x: x)
.collect()))[2:]  # Drop (-1, 0): the first boundary doesn't carry useful information
you can always select:
missing = df_with_block.where(f.col("index").isin(boundary_idxs))
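The comment in the boundary snippet mentions generalizing to any monotonically increasing id by taking the min and max per block. A plain-Python sketch of that idea, assuming each block covers a contiguous range of the sort order (true for a key derived from a monotonically increasing id); boundary_ids is a hypothetical helper, not part of any API:

```python
from itertools import chain

def boundary_ids(rows):
    """Return ids on either side of each block boundary.

    rows: iterable of (id, block) pairs; ids must be increasing but need
    not be consecutive. Assumes each block covers a contiguous id range.
    """
    bounds = {}
    for id_, block in rows:
        lo, hi = bounds.get(block, (id_, id_))
        bounds[block] = (min(lo, id_), max(hi, id_))
    # Order blocks by their first id, then pair each block's max
    # with the following block's min
    ordered = sorted(bounds.values())
    return sorted(chain.from_iterable(
        (prev_hi, next_lo)
        for (_, prev_hi), (next_lo, _) in zip(ordered, ordered[1:])
    ))

# Sequential ids reproduce the boundary list computed with Spark above
seq = [(i, i // 90) for i in range(1000)]
print(boundary_ids(seq)[:4])  # → [89, 90, 179, 180]

# Gapped but increasing ids 0, 3, 6, ..., 27 in blocks of 4 rows
gapped = [(3 * i, i // 4) for i in range(10)]
print(boundary_ids(gapped))   # → [9, 12, 21, 24]
```

Using the per-block max instead of idx - 1 is what removes the dependency on consecutive identifiers.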
These rows can then be filled separately:
# We use a window without partitions here. Since the number of records
# will be small, this won't be a performance issue,
# but it will generate a "Moving all data to a single partition" warning
missing_with_lag = missing.withColumn(
"diffs_col1", f.lag("col1", 1).over(w_unpart) - f.col("col1")
).select("index", f.col("diffs_col1").alias("diffs_fill"))
and join:
combined = (df_lag_with_block
.join(missing_with_lag, ["index"], "leftouter")
.withColumn("diffs_col1", f.coalesce("diffs_col1", "diffs_fill")))
to get the desired result:
mismatched = combined.join(df_lag_unpart, ["index"], "outer").where(
combined["diffs_col1"] != df_lag_unpart["diffs_col1"]
)
assert mismatched.count() == 0
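The same pipeline can be checked end to end without a cluster. The sketch below models each window as a sequential pass over rows sorted by index, with random.gauss standing in for f.randn(42); lag_diffs is a hypothetical helper, not a Spark API. It also makes explicit why the fill step is safe: the lag over boundary rows is correct for block starts (90 follows 89), while the value for a block end (179 following 90) is wrong but never used, because coalesce keeps the non-null within-block value.

```python
import random

def lag_diffs(index, values, keys):
    """Model of f.lag(col, 1).over(window) - col for rows sorted by index.

    keys plays the role of the partitioning column; a constant key models
    a window without partitionBy. Returns {index: diff or None}.
    """
    last_value, out = {}, {}
    for i, value, key in zip(index, values, keys):
        out[i] = last_value[key] - value if key in last_value else None
        last_value[key] = value
    return out

random.seed(42)
n, rec_per_block = 1000, 1000 // 11
index = list(range(n))
col1 = [random.gauss(0.0, 1.0) for _ in index]  # stand-in for f.randn(42)

# Reference result, like df_lag_unpart
expected = lag_diffs(index, col1, [0] * n)

# Within-block lag, like df_lag_with_block
blocked = lag_diffs(index, col1, [i // rec_per_block for i in index])

# Lag over boundary rows only, like missing_with_lag
starts = range(rec_per_block, n, rec_per_block)
boundary = [j for s in starts for j in (s - 1, s)]
fill = lag_diffs(boundary, [col1[j] for j in boundary], [0] * len(boundary))

# Left outer join on index + coalesce(diffs_col1, diffs_fill)
combined = {
    i: blocked[i] if blocked[i] is not None else fill.get(i)
    for i in index
}

mismatched = [i for i in index if combined[i] != expected[i]]
print(len(mismatched))  # → 0
```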