Filtering a Spark DataFrame based on label changes in a time series
The input dataframe has 4 columns: id (str), group (str), elapsed time in days (int), and label (int).
inp = spark.createDataFrame([
['1', "A", 23, 2],
['1', "A", 45, 2],
['1', "A", 73, 2],
['1', "A", 84, 3],
['1', "A", 95, 3],
['1', "A", 101, 2],
['1', "A", 105, 2],
['1', "B", 20, 1],
['1', "B", 40, 1],
['1', "B", 60, 2],
['2', "A", 10, 4],
['2', "A", 20, 4],
['2', "A", 30, 4]
], schema=["id","grp","elap","lbl"])
For every (id, grp) pair, I need the output frame to keep only the records where a different label first occurs.
out = spark.createDataFrame([
['1', "A", 23, 2],
['1', "A", 84, 3],
['1', "A", 101, 2],
['1', "B", 20, 1],
['1', "B", 60, 2],
['2', "A", 10, 4],
], schema=["id","grp","elap","lbl"])
The dataframe has a billion rows, and I am looking for an efficient way to do this.
Solution 1:
Check if the current label is not equal to the previous label (partitioned by id and grp):
from pyspark.sql.window import Window
import pyspark.sql.functions as f

# For each (id, grp), compare every row's label with the previous one (ordered by elap);
# keep the row if it is the first in its partition (prevLbl is null) or the label changed.
inp.withColumn('prevLbl', f.lag('lbl').over(Window.partitionBy('id', 'grp').orderBy('elap')))\
    .filter(f.col('prevLbl').isNull() | (f.col('prevLbl') != f.col('lbl')))\
    .drop('prevLbl').show()
+---+---+----+---+
| id|grp|elap|lbl|
+---+---+----+---+
| 1| A| 23| 2|
| 1| A| 84| 3|
| 1| A| 101| 2|
| 1| B| 20| 1|
| 1| B| 60| 2|
| 2| A| 10| 4|
+---+---+----+---+
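A slightly shorter variant on the same idea (just a sketch) passes a default value to lag so the explicit isNull check for the first row of each partition is not needed. The default of -1 is an assumption here and only works if -1 can never appear as a real label:

from pyspark.sql.window import Window
import pyspark.sql.functions as f

w = Window.partitionBy('id', 'grp').orderBy('elap')

# lag with a default of -1 (assumed never to be a real label) makes the
# first row of every (id, grp) partition count as a label change automatically.
out = (inp
       .withColumn('prevLbl', f.lag('lbl', 1, -1).over(w))
       .filter(f.col('prevLbl') != f.col('lbl'))
       .drop('prevLbl'))
out.show()

Either way, the cost is dominated by a single shuffle on (id, grp) plus a sort by elap within each partition, so it should scale to a billion rows as long as no single (id, grp) group is extremely large.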