How to get unique keys from a Spark Dataset [duplicate]

I have a small dataset:

+-------------------+-------------+--------------+-------+-------------+
|         session_id|  insert_dttm|           key|  value| process_name|
+-------------------+-------------+--------------+-------+-------------+
|local-1641922005078|1641922023703|test_file1.csv|Success|ProcessResult|
|local-1641922005078|1641922023704|test_file1.csv|Success|ProcessResult|
|local-1641922005078|1641922023705|test_file2.csv|Success|ProcessResult|
|local-1641922005078|1641922023706|test_file2.csv|Success|ProcessResult|
|local-1641922005078|1641922023707|test_file3.csv|Success|ProcessResult|
|local-1641922005078|1641922023708|test_file3.csv|Success|ProcessResult|
+-------------------+-------------+--------------+-------+-------------+

I want to get a new dataset containing only unique key values, keeping the row with the latest insert_dttm for each key.

Example Output dataset:

+-------------------+-------------+--------------+-------+-------------+
|         session_id|  insert_dttm|           key|  value| process_name|
+-------------------+-------------+--------------+-------+-------------+
|local-1641922005078|1641922023704|test_file1.csv|Success|ProcessResult|
|local-1641922005078|1641922023706|test_file2.csv|Success|ProcessResult|
|local-1641922005078|1641922023708|test_file3.csv|Success|ProcessResult|
+-------------------+-------------+--------------+-------+-------------+

How can I get such a dataset using the Spark API without using SQL?


You can use this code snippet to deduplicate the rows in Scala:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

val dataframe = (... your dataframe ...)
val rankColumn = "rank"
val window     = Window.partitionBy(col("session_id"), col("key"), col("value"), col("process_name")).orderBy(col("insert_dttm").desc)
val deduplicatedDf = dataframe.withColumn(rankColumn, row_number().over(window)).filter(col(rankColumn) === 1)
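
If you don't want the helper column in the output, you can drop it afterwards; a small follow-up, reusing the names from the snippet above:

val result = deduplicatedDf.drop(rankColumn)  // keeps the original schema
result.show(false)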

This might work:

import org.apache.spark.sql.functions.{col, max}

df.groupBy(
  df.columns
   .filterNot(c => c == "insert_dttm" || c == "session_id")
   .map(col(_)): _*)
  .agg(
    max(df("insert_dttm")).as("insert_dttm"),
    max(df("session_id")).as("session_id"))

This is basically the same as doing this in SQL:

SELECT
  MAX(insert_dttm) AS insert_dttm,
  MAX(session_id) AS session_id,
  <all the other columns>
FROM <your table>
GROUP BY
  <all the other columns>

There is no need for a window function here, which is good to avoid when possible.
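
For reference, here is a minimal self-contained sketch of this groupBy approach on the sample data from the question (the local SparkSession setup and the DedupExample object name are just assumptions for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, max}

object DedupExample {
  def main(args: Array[String]): Unit = {
    // Assumed local setup, purely for the example
    val spark = SparkSession.builder().master("local[*]").appName("dedup-example").getOrCreate()
    import spark.implicits._

    // Sample rows from the question
    val df = Seq(
      ("local-1641922005078", 1641922023703L, "test_file1.csv", "Success", "ProcessResult"),
      ("local-1641922005078", 1641922023704L, "test_file1.csv", "Success", "ProcessResult"),
      ("local-1641922005078", 1641922023705L, "test_file2.csv", "Success", "ProcessResult"),
      ("local-1641922005078", 1641922023706L, "test_file2.csv", "Success", "ProcessResult"),
      ("local-1641922005078", 1641922023707L, "test_file3.csv", "Success", "ProcessResult"),
      ("local-1641922005078", 1641922023708L, "test_file3.csv", "Success", "ProcessResult")
    ).toDF("session_id", "insert_dttm", "key", "value", "process_name")

    // Group by everything except the two aggregated columns, then take the latest values
    val groupCols = df.columns
      .filterNot(c => c == "insert_dttm" || c == "session_id")
      .map(col)
    val result = df
      .groupBy(groupCols: _*)
      .agg(
        max(col("insert_dttm")).as("insert_dttm"),
        max(col("session_id")).as("session_id"))
      .select(df.columns.map(col): _*)  // restore the original column order

    result.show(false)
    spark.stop()
  }
}

On this data it should produce the three rows shown in the expected output above.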