Renaming columns for PySpark DataFrame aggregates

Solution 1:

Although I still prefer dplyr syntax, this code snippet will do:

import pyspark.sql.functions as sf

(df.groupBy("group")
   .agg(sf.sum('money').alias('money'))
   .show(100))

It gets verbose when there are many aggregates, since each one needs its own alias.

Solution 2:

withColumnRenamed should do the trick. It is documented in the pyspark.sql API.

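# Check df.columns first: depending on the Spark version, the generated
# name may be "sum(money)" rather than "SUM(money)".
df.groupBy("group")\
  .agg({"money":"sum"})\
  .withColumnRenamed("SUM(money)", "money")\
  .show(100)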
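
The name passed to withColumnRenamed has to match the generated column exactly, and withColumnRenamed silently does nothing when the old name is not present. The following is only a minimal sketch, assuming a hypothetical helper find_agg_column (not part of PySpark) that looks the generated name up case-insensitively in df.columns instead of hard-coding its capitalization:

```python
def find_agg_column(columns, func, col):
    """Return the generated aggregate column name, e.g. 'sum(money)'.

    Hypothetical helper (not part of PySpark): names are compared
    case-insensitively so both 'SUM(money)' and 'sum(money)' match.
    """
    wanted = "{}({})".format(func, col).lower()
    for name in columns:
        if name.lower() == wanted:
            return name
    raise KeyError("no column matching {!r}".format(wanted))


# Usage with an aggregated DataFrame (assumes `df` already exists):
# agg_df = df.groupBy("group").agg({"money": "sum"})
# old_name = find_agg_column(agg_df.columns, "sum", "money")
# agg_df = agg_df.withColumnRenamed(old_name, "money")
```

Raising an error when nothing matches is a deliberate choice here, so a mismatch is noticed instead of being silently ignored.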

Solution 3:

In Scala, it's as simple as:

import org.apache.spark.sql.functions.max
val maxVideoLenPerItemDf = requiredItemsFiltered.groupBy("itemId")
  .agg(max("playBackDuration").as("customVideoLength"))
maxVideoLenPerItemDf.show()

Use .as in agg to name the new column that is created.

Solution 4:

I made a little helper function for this that might help some people out.

import re

from functools import partial

def rename_cols(agg_df, ignore_first_n=1):
    """Change the default Spark aggregate names such as `avg(colname)`
    to something a bit more useful, such as `avg_colname`. Pass an
    aggregated DataFrame and the number of leading (grouping) columns
    to leave untouched.
    """
    delimiters = "(", ")"
    split_pattern = '|'.join(map(re.escape, delimiters))
    splitter = partial(re.split, split_pattern)
    # 'avg(rate)' splits into ['avg', 'rate', ''], so the join leaves a
    # trailing underscore; strip it independently of ignore_first_n.
    split_agg = lambda x: '_'.join(splitter(x)).rstrip('_')
    renamed = map(split_agg, agg_df.columns[ignore_first_n:])
    renamed = zip(agg_df.columns[ignore_first_n:], renamed)
    for old, new in renamed:
        agg_df = agg_df.withColumnRenamed(old, new)
    return agg_df

An example:

gb = (df.selectExpr("id", "rank", "rate", "price", "clicks")
 .groupby("id")
 .agg({"rank": "mean",
       "*": "count",
       "rate": "mean", 
       "price": "mean", 
       "clicks": "mean", 
       })
)

>>> gb.columns
['id',
 'avg(rate)',
 'count(1)',
 'avg(price)',
 'avg(rank)',
 'avg(clicks)']

>>> rename_cols(gb).columns
['id',
 'avg_rate',
 'count_1',
 'avg_price',
 'avg_rank',
 'avg_clicks']

It at least saves people some typing.
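
A variation on the same idea, offered only as a sketch: compute all the new names first and hand them to toDF, which replaces every column name of a DataFrame in one call, instead of chaining withColumnRenamed. The helper name flatten_agg_name is hypothetical, and it only handles single-level names of the form func(column):

```python
import re

# Matches single-level aggregate names such as 'avg(rate)' or 'count(1)'.
_AGG_NAME = re.compile(r"^(\w+)\((.*)\)$")


def flatten_agg_name(name):
    """Turn 'avg(rate)' into 'avg_rate'; other names are returned unchanged.

    Hypothetical helper, not part of PySpark.
    """
    match = _AGG_NAME.match(name)
    if match is None:
        return name
    func, inner = match.groups()
    return "{}_{}".format(func, inner)


# Usage with the aggregated DataFrame from the example (assumes `gb` exists):
# renamed = gb.toDF(*[flatten_agg_name(c) for c in gb.columns])
```

Because names without parentheses pass through untouched, there is no need to tell the function how many grouping columns to skip.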

Solution 5:

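# aggregate_methods is defined elsewhere
df = df.groupby('Device_ID').agg(aggregate_methods)
for column in df.columns:
    start_index = column.find('(')
    end_index = column.find(')')
    # str.find returns -1 (which is truthy) when the character is absent,
    # so compare against -1 explicitly.
    if start_index != -1 and end_index != -1:
        df = df.withColumnRenamed(column, column[start_index+1:end_index])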

The above code strips out everything outside of the "()". For example, "sum(foo)" will be renamed to "foo".
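
One caveat with keeping only the text inside the parentheses: if the same column is aggregated in more than one way (say "sum(foo)" and "avg(foo)"), both would be renamed to "foo". The sketch below, built around a hypothetical helper inner_names (not part of PySpark), shortens a name only when the short form is unique among the resulting column names:

```python
import re
from collections import Counter

# Matches names of the form func(inner), e.g. 'sum(foo)'.
_AGG_NAME = re.compile(r"^\w+\((.*)\)$")


def inner_names(columns):
    """Map aggregate column names to the text inside their parentheses.

    Hypothetical helper: a name is only included in the mapping when its
    shortened form does not clash with any other resulting column name.
    """
    candidates = []
    for name in columns:
        match = _AGG_NAME.match(name)
        candidates.append(match.group(1) if match else name)
    counts = Counter(candidates)
    return {
        old: new
        for old, new in zip(columns, candidates)
        if old != new and counts[new] == 1
    }


# Usage with an aggregated DataFrame (assumes `df` already exists):
# for old, new in inner_names(df.columns).items():
#     df = df.withColumnRenamed(old, new)
```

Names that would clash are simply left as generated, so they can be renamed by hand.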