PySpark: aggregate mode (most frequent) value in a rolling window
I have a DataFrame like the one below. I would like to group by device and order by start_time within each group. Then, for each row in the group, I want the most frequently occurring station in a window of 3 rows ending at that row (the row itself plus the 2 rows before it).
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.getOrCreate()
columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]
test_df = spark.createDataFrame(data).toDF(*columns)
# Frame: the current row plus the 2 rows before it, per device, ordered by start_time
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)
Desired output:
+------+----------+---------+--------------------+
|device|start_time|  station|rolling_mode_station|
+------+----------+---------+--------------------+
|Python|         1|station_1|           station_1|
|Python|         2|station_2|           station_2|
|Python|         3|station_1|           station_1|
|Python|         4|station_2|           station_2|
|Python|         5|station_2|           station_2|
|Python|         6|     null|           station_2|
+------+----------+---------+--------------------+
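To pin down what the desired column computes, here is a plain-Python reference sketch (not Spark code) over the station values of one device, already sorted by start_time. Two rules are assumptions inferred from the table rather than stated in the question: nulls are ignored (row 6 still yields station_2), and a tie goes to the most recent value (row 2 holds one station_1 and one station_2, and the desired answer is station_2). The function name `rolling_mode` is hypothetical.

```python
from collections import Counter
from typing import Iterable, List, Optional


def rolling_mode(values: Iterable[Optional[str]], size: int = 3) -> List[Optional[str]]:
    """Mode of each trailing window of `size` rows (current row plus size-1 before it).

    Mirrors rowsBetween(-(size - 1), 0). Assumptions: None values are skipped,
    and among tied values the one seen most recently wins.
    """
    values = list(values)
    result: List[Optional[str]] = []
    for i in range(len(values)):
        # Rows i-size+1 .. i, clipped at the start of the partition, nulls dropped.
        frame = [v for v in values[max(0, i - size + 1): i + 1] if v is not None]
        if not frame:
            result.append(None)  # window contains only nulls
            continue
        counts = Counter(frame)
        best = max(counts.values())
        # Walk the frame backwards so the most recent tied value is chosen.
        result.append(next(v for v in reversed(frame) if counts[v] == best))
    return result


stations = ["station_1", "station_2", "station_1", "station_2", "station_2", None]
print(rolling_mode(stations))
# ['station_1', 'station_2', 'station_1', 'station_2', 'station_2', 'station_2']
```

The printed list matches the rolling_mode_station column of the desired output row by row.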
PySpark does not have a mode() function. I know how to get the most frequent value with a static groupBy as shown here, but I don't know how to adapt that to a rolling window.
Solution 1:
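You can use the collect_list function over the window defined above to gather the stations from the last 3 rows, then compute the most frequent element of each resulting array.
To get the most frequent element of an array, you can either explode it and then group by and count, as in the post you already linked, or use a UDF like this one. Note that ties (row 2, where station_1 and station_2 each appear once) are broken arbitrarily, which is why row 2 below shows station_1 rather than the desired station_2: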
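import pyspark.sql.functions as F

test_df.withColumn(
    # Array of the non-null stations in the frame (collect_list skips nulls)
    "rolling_mode_station",
    F.collect_list("station").over(rolling_w)
).withColumn(
    # Most frequent element of each array; None if the frame had no stations
    "rolling_mode_station",
    F.udf(lambda x: max(set(x), key=x.count) if x else None)(F.col("rolling_mode_station"))
).show()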
#+------+----------+---------+--------------------+
#|device|start_time|  station|rolling_mode_station|
#+------+----------+---------+--------------------+
#|Python|         1|station_1|           station_1|
#|Python|         2|station_2|           station_1|
#|Python|         3|station_1|           station_1|
#|Python|         4|station_2|           station_2|
#|Python|         5|station_2|           station_2|
#|Python|         6|     null|           station_2|
#+------+----------+---------+--------------------+
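If you need a deterministic tie-break, note that `max(set(x), key=x.count)` returns whichever tied value the set yields first, and set iteration order for strings depends on Python's per-process hash randomization, so the result can differ between runs or executors. Below is a sketch of an alternative UDF body; the name `mode_latest` and the "latest value wins" rule are assumptions chosen to reproduce the desired output. It also assumes collect_list returns the frame's values in start_time order, which is what happens over an ordered row frame in practice but is not a documented guarantee.

```python
from collections import Counter
from typing import List, Optional


def mode_latest(window_values: Optional[List[str]]) -> Optional[str]:
    """Most frequent element of a collect_list array; ties go to the latest element.

    Hypothetical UDF body: unlike max(set(x), key=x.count), the result does
    not depend on set iteration order.
    """
    if not window_values:  # null input or empty array (all-null window)
        return None
    counts = Counter(window_values)
    best = max(counts.values())
    # Scan from the end of the array so the most recent tied value wins.
    return next(v for v in reversed(window_values) if counts[v] == best)


print(mode_latest(["station_1", "station_2"]))  # station_2
print(mode_latest(["station_2", "station_1", "station_2"]))  # station_2
print(mode_latest([]))  # None
```

To use it in place of the lambda, something like `F.udf(mode_latest, StringType())(F.collect_list("station").over(rolling_w))` (with `from pyspark.sql.types import StringType`) should give station_2 on row 2, as in the desired output.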