Filtering a Spark dataframe based on label changes in a time series

The input dataframe has 4 columns: id (str), group (str), elapsed time in days (int), and label (int).

inp = spark.createDataFrame([
  ['1', "A",  23, 2],
  ['1', "A",  45, 2],
  ['1', "A",  73, 2],
  ['1', "A",  84, 3],
  ['1', "A",  95, 3],
  ['1', "A", 101, 2],
  ['1', "A", 105, 2],
  ['1', "B",  20, 1],
  ['1', "B",  40, 1],
  ['1', "B",  60, 2],
  ['2', "A",  10, 4],
  ['2', "A",  20, 4],
  ['2', "A",  30, 4]
], schema=["id","grp","elap","lbl"])

For every (id, grp), I need the output frame to keep only the first occurrence of each new label, i.e., the rows where the label changes relative to the previous row when ordered by elapsed time.

out = spark.createDataFrame([
  ['1', "A",  23, 2],
  ['1', "A",  84, 3],
  ['1', "A", 101, 2],
  ['1', "B",  20, 1],
  ['1', "B",  60, 2],
  ['2', "A",  10, 4],
], schema=["id","grp","elap","lbl"])

The dataframe has a billion rows, so I'm looking for an efficient way to do this.


Solution 1:

Check whether the current label differs from the previous label within each (id, grp) partition, ordered by elap:

from pyspark.sql.window import Window
import pyspark.sql.functions as f

# Within each (id, grp), compare each label to the previous one (ordered by elap)
# and keep a row when there is no previous row or the label changed.
inp.withColumn('prevLbl', f.lag('lbl').over(Window.partitionBy('id', 'grp').orderBy('elap')))\
    .filter(f.col('prevLbl').isNull() | (f.col('prevLbl') != f.col('lbl')))\
    .drop('prevLbl').show()

+---+---+----+---+
| id|grp|elap|lbl|
+---+---+----+---+
|  1|  A|  23|  2|
|  1|  A|  84|  3|
|  1|  A| 101|  2|
|  1|  B|  20|  1|
|  1|  B|  60|  2|
|  2|  A|  10|  4|
+---+---+----+---+
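
This should scale reasonably well to a billion rows: the window only needs lag over a single partitioning by (id, grp), so Spark performs one shuffle keyed on those columns followed by a per-partition sort, with no joins or aggregations. As a quick sanity check against the expected output above, here is a minimal sketch (assuming the inp and out dataframes from the question and Spark >= 2.4, which provides exceptAll) that compares the result with the expected frame:

from pyspark.sql.window import Window
import pyspark.sql.functions as f

w = Window.partitionBy('id', 'grp').orderBy('elap')

# Same logic as above, kept as a named dataframe instead of calling .show()
res = (inp
       .withColumn('prevLbl', f.lag('lbl').over(w))
       .filter(f.col('prevLbl').isNull() | (f.col('prevLbl') != f.col('lbl')))
       .drop('prevLbl'))

# Both directions of exceptAll should be empty if res matches the expected output
assert res.exceptAll(out).count() == 0
assert out.exceptAll(res).count() == 0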