Fill in nulls with the previously known good value in PySpark

Solution 1:

I believe I have a much simpler solution than the accepted one. It also uses window functions, but relies on the function 'last' with its ignorenulls flag set, so null values are skipped.

Let's re-create something similar to the original data:

import sys
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
import pyspark.sql.functions as func

spark = SparkSession.builder.getOrCreate()

# None becomes null; the explicit schema fixes the column order and types
d = [(1, 1, None), (1, 2, 109), (1, 3, None), (1, 4, 110), (1, 5, None), (1, 6, None)]
df = spark.createDataFrame(d, 'session long, ts long, id long')

Calling df.show() prints:

+-------+---+----+
|session| ts|  id|
+-------+---+----+
|      1|  1|null|
|      1|  2| 109|
|      1|  3|null|
|      1|  4| 110|
|      1|  5|null|
|      1|  6|null|
+-------+---+----+

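Now apply the window function 'last' with ignorenulls=True, over a frame that runs from the first row of each session up to the current row: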

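w = (Window.partitionBy('session')
     .orderBy('ts')
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))

# ignorenulls=True: take the most recent non-null 'id' in the frame
df.withColumn('id', func.last('id', ignorenulls=True).over(w)).show()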

We just get:

+-------+---+----+
|session| ts|  id|
+-------+---+----+
|      1|  1|null|
|      1|  2| 109|
|      1|  3| 109|
|      1|  4| 110|
|      1|  5| 110|
|      1|  6| 110|
+-------+---+----+
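To see why this works, it helps to spell out what the window computes for each row. The sketch below is plain Python, not Spark: it models one session ordered by ts, builds the frame that rowsBetween(unboundedPreceding, currentRow) describes (every row from the start of the session up to the current one) and keeps the last non-null value in it, which is what last(..., ignorenulls=True) returns. The function name and the (ts, id) tuple layout are assumptions for illustration only.

```python
from typing import List, Optional, Tuple

Row = Tuple[int, Optional[int]]  # (ts, id); None plays the role of SQL null


def last_ignore_nulls_running(rows: List[Row]) -> List[Optional[int]]:
    """Model last(id, ignorenulls=True) over a frame from the first row to the current row."""
    ordered = sorted(rows, key=lambda r: r[0])  # like orderBy('ts')
    result = []
    for i in range(len(ordered)):
        frame = [value for _, value in ordered[: i + 1]]  # unbounded preceding .. current row
        non_null = [v for v in frame if v is not None]
        result.append(non_null[-1] if non_null else None)  # null if nothing seen yet
    return result


session_1 = [(1, None), (2, 109), (3, None), (4, 110), (5, None), (6, None)]
print(last_ignore_nulls_running(session_1))
# [None, 109, 109, 110, 110, 110]
```

The first row stays null because its frame contains no non-null id yet, which matches the output above.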

Hope it helps!

Solution 2:

This seems to do the trick using window functions:

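from pyspark.sql.window import Window
import pyspark.sql.functions as func

def fill_nulls(df):
    # Order rows within each session by timestamp ('ts', as in the example data above)
    w = Window.partitionBy('session').orderBy('ts')

    # Use -1 as a stand-in for null (assumes -1 is never a real id)
    df_na = df.na.fill(-1, subset=['id'])
    lag = df_na.withColumn('id_lag', func.lag('id', default=-1).over(w))

    # Flag rows where a new non-null id appears
    switch = lag.withColumn('id_change',
                            ((lag['id'] != lag['id_lag']) &
                             (lag['id'] != -1)).cast('integer'))

    # A running sum of the flags numbers the sub-sessions
    switch_sess = switch.withColumn(
        'sub_session',
        func.sum('id_change').over(
            w.rowsBetween(Window.unboundedPreceding, Window.currentRow)))

    # Every row of a sub-session takes the id that started it
    fid = switch_sess.withColumn(
        'nn_id',
        func.first('id').over(
            Window.partitionBy('session', 'sub_session').orderBy('ts')))

    # Turn the sentinel back into a real null
    # (replace(-1, 'null') raises a mixed-type error and would not give a null anyway)
    fid_na = fid.withColumn(
        'nn_id', func.when(func.col('nn_id') != -1, func.col('nn_id')))

    ff = (fid_na.drop('id', 'id_lag', 'id_change', 'sub_session')
                .withColumnRenamed('nn_id', 'id'))

    return ff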

Here is the full
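The idea behind this approach can be checked without a cluster. The plain-Python model below (a hypothetical helper, not part of the answer) takes one session whose rows are already sorted by timestamp through the same steps: replace nulls with the -1 sentinel, compare each id with the previous one, number the sub-sessions with a running sum of the change flags, give every row the first id of its sub-session, and finally turn the sentinel back into null.

```python
from typing import Dict, List, Optional

SENTINEL = -1  # stand-in for null, like df.na.fill(-1); assumes -1 is never a real id


def fill_via_sub_sessions(ids: List[Optional[int]]) -> List[Optional[int]]:
    """Model the answer's steps for one session whose rows are already sorted by ts."""
    filled = [SENTINEL if v is None else v for v in ids]  # na.fill(-1)
    lagged = [SENTINEL] + filled[:-1]                       # lag('id', default=-1)
    # id_change: 1 where a non-null id differs from the previous row's id
    change = [int(cur != prev and cur != SENTINEL) for cur, prev in zip(filled, lagged)]

    # sub_session: running sum of the change flags
    sub_sessions, total = [], 0
    for flag in change:
        total += flag
        sub_sessions.append(total)

    # nn_id: first id of each sub-session (the row that started it)
    first_id: Dict[int, int] = {}
    for sub, value in zip(sub_sessions, filled):
        first_id.setdefault(sub, value)

    # Turn the sentinel back into None
    return [None if first_id[sub] == SENTINEL else first_id[sub] for sub in sub_sessions]


print(fill_via_sub_sessions([None, 109, None, 110, None, None]))
# [None, 109, 109, 110, 110, 110]
```

Because -1 doubles as the null marker, a genuine id of -1 would come out as null, so the sentinel has to be a value that cannot occur in the data.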

Solution 3:

@Oleksiy's answer is great, but didn't fully work for my requirements. Within a session, if multiple nulls are observed, all are filled with the first non-null for the session. I needed the last non-null value to propagate forward.

The following tweak worked for my use case:

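from pyspark.sql.window import Window
import pyspark.sql.functions as func

def fill_forward(df, id_column, key_column, fill_column):

    # Fill nulls with the last *non-null* value in the window
    ff = df.withColumn(
        'fill_fwd',
        func.last(fill_column, True)  # True: ignore nulls, i.e. fill with last non-null
        .over(Window.partitionBy(id_column)
              .orderBy(key_column)
              .rowsBetween(Window.unboundedPreceding, Window.currentRow)))

    # Drop the old column and rename the new column
    ff_out = ff.drop(fill_column).withColumnRenamed('fill_fwd', fill_column)

    return ff_out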
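In fill_forward, id_column is the partition key and key_column the sort key. The plain-Python model below (hypothetical name fill_forward_model, rows as dicts; not Spark code) mirrors that window: rows are grouped by id_column, sorted by key_column regardless of their input order, and the last value seen is reset at every partition boundary, so a value from one session never leaks into the next.

```python
from itertools import groupby
from typing import Any, Dict, List


def fill_forward_model(rows: List[Dict[str, Any]], id_column: str,
                       key_column: str, fill_column: str) -> List[Dict[str, Any]]:
    """Plain-Python model of fill_forward(): within each id_column partition,
    ordered by key_column, replace None in fill_column with the latest non-None value."""
    # Sorting by (partition, order key) mimics partitionBy(id_column).orderBy(key_column)
    ordered = sorted(rows, key=lambda r: (r[id_column], r[key_column]))
    out = []
    for _, group in groupby(ordered, key=lambda r: r[id_column]):
        last_seen = None  # reset at every partition boundary
        for row in group:
            if row[fill_column] is not None:
                last_seen = row[fill_column]
            out.append({**row, fill_column: last_seen})  # copy; input rows are not mutated
    return out


rows = [
    {"session": 2, "ts": 2, "id": None},
    {"session": 1, "ts": 2, "id": 109},
    {"session": 2, "ts": 1, "id": 7},
    {"session": 1, "ts": 1, "id": None},
    {"session": 1, "ts": 3, "id": None},
]
result = fill_forward_model(rows, id_column="session", key_column="ts", fill_column="id")
print([(r["session"], r["ts"], r["id"]) for r in result])
# [(1, 1, None), (1, 2, 109), (1, 3, 109), (2, 1, 7), (2, 2, 7)]
```

Note that the session 2 row with ts = 2 picks up 7 from ts = 1 even though it appears first in the input, because ordering comes from key_column, not from the physical row order.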