Group by overlapping ranges and keep the max value in Scala Spark
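You can use a cumulative conditional sum along with the lag function to define a group column. A row gets a flag of 1 when its start does not fall between the previous row's start and expiration, and the running sum of these flags gives every run of overlapping rows the same group number. Then simply group by customerid + group and take the min start and the max expiration. To get the id value associated with the max expiration date, you can use this trick with struct ordering: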
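import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Per customer, process rows in order of start
val w = Window.partitionBy("customerid").orderBy("start")

val result = df.withColumn(
  "group",
  // 0 when this row's start lies within the previous row's [start, expiration],
  // 1 otherwise (including the first row, where lag is null);
  // the running sum over w turns these flags into a group number
  sum(
    when(
      col("start").between(lag("start", 1).over(w), lag("expiration", 1).over(w)),
      0
    ).otherwise(1)
  ).over(w)
).groupBy("customerid", "group").agg(
  min(col("start")).as("start"),
  // structs compare field by field, so max picks the latest expiration and keeps its id
  max(struct(col("expiration"), col("id"))).as("max")
).select("max.id", "customerid", "start", "max.expiration")

result.show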
//+---+----------+-----+----------+
//| id|customerid|start|expiration|
//+---+----------+-----+----------+
//|  5|      0002|01321|     02143|
//|  4|      0002|39271|     40231|
//|  2|      0001|11943|     28432|
//+---+----------+-----+----------+
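To check the grouping rule without a Spark cluster, here is a minimal plain-Scala sketch of the same steps on an in-memory collection. Record, OverlapGroups, OverlapDemo, the runningMax parameter and the sample data are all hypothetical names for illustration, not part of the question or of Spark. With runningMax = false it mirrors the query above: a row joins the current group only if its start lies between the previous row's start and expiration. Because only the immediately preceding row is consulted, a short range nested inside a longer earlier one can split a group (customer 0001 in the sample). With runningMax = true, each row is compared with the largest expiration seen so far in the partition instead. In Spark that variant would presumably replace the lag-based condition with something like col("start") <= max("expiration").over(w.rowsBetween(Window.unboundedPreceding, -1)), but treat that as an untested suggestion. Like the answer's data, the sketch compares start and expiration as strings, so it assumes they are fixed-width and zero-padded.

```scala
// Plain-Scala sketch of the grouping logic (no Spark), for illustration only.
// Hypothetical record type; fields mirror the DataFrame columns.
final case class Record(id: Int, customerid: String, start: String, expiration: String)

object OverlapGroups {
  // Assign a group number to each record, per customer, in order of start.
  // runningMax = false: compare with the previous row only (like lag in the answer).
  // runningMax = true: compare with the max expiration of all earlier rows (hypothetical variant).
  def assignGroups(records: Seq[Record], runningMax: Boolean = false): Seq[(Record, Int)] =
    records.groupBy(_.customerid).values.toSeq.flatMap { custRecords =>
      val sorted = custRecords.sortBy(_.start)
      // flag = 1 starts a new group; the first row always starts one (lag is null in Spark)
      val flags = sorted.indices.map { i =>
        if (i == 0) 1
        else {
          val prev = sorted(i - 1)
          // Upper bound of the overlap test (quadratic for runningMax, fine for a sketch)
          val upper =
            if (runningMax) sorted.take(i).map(_.expiration).max
            else prev.expiration
          val s = sorted(i).start
          // Same test as between: prev.start <= s <= upper (lexicographic on padded strings)
          val overlaps = s.compareTo(prev.start) >= 0 && s.compareTo(upper) <= 0
          if (overlaps) 0 else 1
        }
      }
      // Running sum of the flags = group number, like sum(...).over(w)
      sorted.zip(flags.scanLeft(0)(_ + _).tail)
    }

  // Collapse each (customerid, group): min start, plus id and expiration of the record
  // with the greatest (expiration, id) pair; tuple ordering stands in for struct ordering.
  def collapse(records: Seq[Record], runningMax: Boolean = false): Seq[Record] =
    assignGroups(records, runningMax)
      .groupBy { case (r, g) => (r.customerid, g) }
      .values
      .map { members =>
        val rs = members.map(_._1)
        val best = rs.maxBy(r => (r.expiration, r.id))
        Record(best.id, best.customerid, rs.map(_.start).min, best.expiration)
      }
      .toSeq
      .sortBy(r => (r.customerid, r.start))
}

object OverlapDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical sample: for customer 0001, record 3 overlaps record 1 but not record 2
    val sample = Seq(
      Record(1, "0001", "00001", "00010"),
      Record(2, "0001", "00002", "00003"),
      Record(3, "0001", "00005", "00008"),
      Record(4, "0002", "00100", "00200"),
      Record(5, "0002", "00150", "00300"),
      Record(6, "0002", "00400", "00500")
    )
    OverlapGroups.collapse(sample).foreach(println)                    // 4 rows: ids 1, 3, 5, 6
    OverlapGroups.collapse(sample, runningMax = true).foreach(println) // 3 rows: ids 1, 5, 6
  }
}
```

The previous-row rule is cheaper (one lag per row), while the running-max rule is what you need if a long range can contain several shorter ones that should all end up in the same group.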