Adding a group count column to a PySpark dataframe

I am coming from R and the tidyverse to PySpark due to its superior Spark handling, and I am struggling to map certain concepts from one context to the other.

In particular, suppose that I had a dataset like the following:

x | y
--+--
a | 5
a | 8
a | 7
b | 1

and I wanted to add a column containing the number of rows for each x value, like so:

x | y | n
--+---+---
a | 5 | 3
a | 8 | 3
a | 7 | 3
b | 1 | 1

In dplyr, I would just say:

library(tidyverse)

df <- read_csv("...")
df %>%
    group_by(x) %>%
    mutate(n = n()) %>%
    ungroup()

and that would be that. I can do something almost as simple in PySpark if I'm looking to summarize by number of rows:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .count() \
    .show()

And I thought I understood that withColumn was equivalent to dplyr's mutate. However, when I do the following, PySpark tells me that withColumn is not defined for groupBy data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .withColumn("n", count("x")) \
    .show()

In the short run, I can simply create a second dataframe containing the counts and join it to the original dataframe. However, it seems like this could become inefficient in the case of large tables. What is the canonical way to accomplish this?


groupBy() returns a GroupedData object rather than a DataFrame, which is why withColumn is not available on it. You have to specify an aggregation before you can display the results. For example:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
]
df = sqlCtx.createDataFrame(data, ["x", "y"])
df.groupBy('x').count().select('x', f.col('count').alias('n')).show()
#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+

Here I used alias() to rename the column. But this only returns one row per group. If you want all rows with the count appended, you can do this with a Window:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.select('x', 'y', f.count('x').over(w).alias('n')).sort('x', 'y').show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+

Or if you're more comfortable with SQL, you can register the dataframe as a temporary view and use Spark SQL to do the same thing:

df.createOrReplaceTempView('table')  # registerTempTable is deprecated since Spark 2.0
sqlCtx.sql(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+

As an appendix to @pault's answer, the same per-group count can be written with agg():

import pyspark.sql.functions as F

...

(df
.groupBy(F.col('x'))
.agg(F.count('x').alias('n'))
.show())

#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+

enjoy


I found that we can get even closer to the tidyverse example by using withColumn with a window:

import pyspark.sql.functions as f
from pyspark.sql import Window
w = Window.partitionBy('x')
df.withColumn('n', f.count('x').over(w)).sort('x', 'y').show()