Group by and save the max value with overlapping columns in scala spark

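You can use a cumulative conditional sum along with the `lag` function to define a `group` column that gives consecutive overlapping rows the same group number. Then simply group by `customerid` + `group` and take the min `start` and the max `expiration`. To get the `id` associated with the max expiration date, you can use a trick with struct ordering: `max` over a struct compares its fields left to right, so `max(struct(expiration, id))` returns the `id` from the row with the latest expiration: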

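```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Order each customer's rows by start date
val w = Window.partitionBy("customerid").orderBy("start")

val result = df.withColumn(
    "group",
    // 0 if this row's start falls inside the previous row's [start, expiration]
    // range, otherwise 1 (also for the first row, where lag returns null);
    // the running sum gives every overlapping run the same group number
    sum(
      when(
        col("start").between(lag("start", 1).over(w), lag("expiration", 1).over(w)),
        0
      ).otherwise(1)
    ).over(w)
).groupBy("customerid", "group").agg(
    min(col("start")).as("start"),
    // struct ordering: latest expiration wins, its id comes along
    max(struct(col("expiration"), col("id"))).as("max")
).select("max.id", "customerid", "start", "max.expiration")

result.show()
//+---+----------+-----+----------+
//| id|customerid|start|expiration|
//+---+----------+-----+----------+
//|  5|      0002|01321|     02143|
//|  4|      0002|39271|     40231|
//|  2|      0001|11943|     28432|
//+---+----------+-----+----------+
```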
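To make the grouping logic easier to follow without a Spark session, here is a minimal plain-Scala sketch of the same steps on an in-memory list: sort each customer's rows by `start`, flag a row 0 when its start falls inside the previous row's range (else 1), take a running sum of the flags as the group number, and keep the minimum start plus the maximum `(expiration, id)` pair per group, with tuple ordering standing in for Spark's struct ordering. The `Policy` case class, the `OverlapGroups` object and the sample values are hypothetical. It assumes, as the leading zeros in the output suggest, that the columns are fixed-width strings, so lexicographic comparison matches numeric order.

```scala
// Hypothetical record type mirroring the columns used above (illustration only).
final case class Policy(id: Int, customerid: String, start: String, expiration: String)

object OverlapGroups {
  // Same steps as the Spark query, for the rows of a single customer.
  def merge(rows: Seq[Policy]): Seq[Policy] = {
    // Window orderBy("start")
    val sorted = rows.sortBy(_.start)

    // when(start between previous start and previous expiration, 0).otherwise(1);
    // the first row has no previous row (lag is null in Spark), so it gets 1.
    val flags = sorted.indices.map { i =>
      val overlapsPrev = i > 0 && {
        val prev = sorted(i - 1)
        sorted(i).start.compareTo(prev.start) >= 0 &&
          sorted(i).start.compareTo(prev.expiration) <= 0
      }
      if (overlapsPrev) 0 else 1
    }

    // Running sum of the flags = group number (sum(...).over(w)).
    val groups = flags.scanLeft(0)(_ + _).tail

    sorted.zip(groups)
      .groupBy { case (_, g) => g }
      .toSeq
      .sortBy { case (g, _) => g }
      .map { case (_, members) =>
        val ps = members.map { case (p, _) => p }
        // Tuples compare element by element, like Spark structs:
        // the latest expiration wins and its id comes along.
        val (maxExp, maxId) = ps.map(p => (p.expiration, p.id)).max
        Policy(maxId, ps.head.customerid, ps.map(_.start).min, maxExp)
      }
  }

  // groupBy("customerid", ...) across all customers, sorted by customer id.
  def mergeAll(rows: Seq[Policy]): Seq[Policy] =
    rows.groupBy(_.customerid).toSeq.sortBy(_._1).flatMap { case (_, rs) => merge(rs) }
}
```

Like the Spark query, the sketch only compares each row with its immediate predecessor: with the ranges 00001 to 00100, 00002 to 00003 and 00050 to 00060, the third range ends up in a separate group even though it falls inside the first. If such nested ranges can occur, one possible adjustment (not tested here) is to replace `lag("expiration", 1).over(w)` with the running maximum of the earlier expirations, `max("expiration").over(w.rowsBetween(Window.unboundedPreceding, -1))`.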