Retrieve top n in each group of a DataFrame in pyspark
Solution 1:
I believe you need to use window functions to rank each row based on user_id and score, and then filter your results to keep only the first two values.
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+
In general, the official programming guide is a good place to start learning Spark.
Data
rdd = sc.parallelize([("user_1", "object_1", 3),
                      ("user_1", "object_2", 2),
                      ("user_2", "object_1", 5),
                      ("user_2", "object_2", 2),
                      ("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
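For reference, here is a minimal sketch of the same top-2-per-group query expressed in Spark SQL rather than the DataFrame API, assuming a Spark 2+ session named spark; the view name "scores" is hypothetical.
# Register the DataFrame under a temp view so it can be queried with SQL
df.createOrReplaceTempView("scores")
spark.sql("""
    SELECT user_id, object_id, score, rank
    FROM (
        SELECT *,
               RANK() OVER (PARTITION BY user_id ORDER BY score DESC) AS rank
        FROM scores
    ) ranked
    WHERE rank <= 2
""").show()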
Solution 2:
Top-n results are more accurate if you use row_number instead of rank, because rank produces duplicate positions when scores are tied:
from pyspark.sql.functions import col, row_number

n = 5
# reuses the window definition from Solution 1
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()
Note: use the limit(20).toPandas() trick instead of show() in Jupyter notebooks for nicer formatting.
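To see why this matters, here is a minimal sketch comparing rank and row_number on tied data; the ties DataFrame below is hypothetical and not part of the question.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, rank, row_number

ties = spark.createDataFrame(
    [("user_1", "object_1", 3),
     ("user_1", "object_2", 3),
     ("user_1", "object_3", 3)],
    ["user_id", "object_id", "score"])
w = Window.partitionBy("user_id").orderBy(col("score").desc())
# rank() yields 1, 1, 1 for the three tied rows, so a rank <= 2 filter keeps all of them;
# row_number() yields 1, 2, 3 (tie order among equals is arbitrary), so row_number <= 2
# keeps exactly two rows per group.
ties.select("*",
            rank().over(w).alias("rank"),
            row_number().over(w).alias("row_number")).show()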
Solution 3:
I know the question was asked for pyspark, and I was looking for a similar answer in Scala, i.e. how to retrieve the top n values in each group of a DataFrame in Scala. Here is the Scala version of @mtoto's answer.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col
val window = Window.partitionBy("user_id").orderBy(col("score").desc)
val rankByScore = rank().over(window)
df1.select(col("*"), rankByScore.alias("rank")).filter(col("rank") <= 2).show()
// you can change the value 2 to any number you want; here 2 represents the top 2 values
Solution 4:
Here is another solution, without a window function, that gets the top N records overall (not per group) from a pySpark DataFrame.
# Import Libraries
from pyspark.sql.functions import col
# Sample Data
rdd = sc.parallelize([("user_1", "object_1", 3),
                      ("user_1", "object_2", 2),
                      ("user_2", "object_1", 5),
                      ("user_2", "object_2", 2),
                      ("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)
# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)
# Display DataFrame
sorted_df.show()
Output
+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_2| object_2|    6|
| user_2| object_1|    5|
| user_1| object_1|    3|
| user_1| object_2|    2|
| user_2| object_2|    2|
+-------+---------+-----+
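As a side note, head() collects the rows to the driver and createDataFrame rebuilds a DataFrame from them; a sketch of an equivalent that stays distributed is to use limit() directly:
# limit(5) keeps everything as a DataFrame transformation,
# avoiding the collect-and-rebuild round trip through the driver.
sorted_df = df.orderBy(col("score").desc()).limit(5)
sorted_df.show()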
If you are interested in more window functions in Spark you can refer to one of my blogs: https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86