How to split a list to multiple columns in Pyspark?

Solution 1:

It depends on the type of your "list":

  • If it is of type ArrayType():

    df = hc.createDataFrame(sc.parallelize([['a', [1,2,3]], ['b', [2,3,4]]]), ["key", "value"])
    df.printSchema()
    df.show()
    root
     |-- key: string (nullable = true)
     |-- value: array (nullable = true)
     |    |-- element: long (containsNull = true)
    

    +---+-------+
    |key|  value|
    +---+-------+
    |  a|[1,2,3]|
    |  b|[2,3,4]|
    +---+-------+

    you can access the values like you would in Python, using [] (a sketch that avoids typing each index by hand follows after this list):

    df.select("key", df.value[0], df.value[1], df.value[2]).show()
    +---+--------+--------+--------+
    |key|value[0]|value[1]|value[2]|
    +---+--------+--------+--------+
    |  a|       1|       2|       3|
    |  b|       2|       3|       4|
    +---+--------+--------+--------+
    
  • If it is of type StructType() (for example, if you built your dataframe by reading JSON):

    import pyspark.sql.functions as psf

    df2 = df.select("key", psf.struct(
            df.value[0].alias("value1"),
            df.value[1].alias("value2"),
            df.value[2].alias("value3")
        ).alias("value"))
    df2.printSchema()
    df2.show()
    root
     |-- key: string (nullable = true)
     |-- value: struct (nullable = false)
     |    |-- value1: long (nullable = true)
     |    |-- value2: long (nullable = true)
     |    |-- value3: long (nullable = true)
    
    +---+-------+
    |key|  value|
    +---+-------+
    |  a|[1,2,3]|
    |  b|[2,3,4]|
    +---+-------+
    

    you can directly 'split' the column using *:

    df2.select('key', 'value.*').show()
    +---+------+------+------+
    |key|value1|value2|value3|
    +---+------+------+------+
    |  a|     1|     2|     3|
    |  b|     2|     3|     4|
    +---+------+------+------+
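
To close the ArrayType() case above: typing value[0], value[1], ... by hand gets tedious for longer arrays. A minimal sketch (not part of the original answer), assuming the array length n is known, generates the same selection with a list comprehension and getItem():

    import pyspark.sql.functions as F

    n = 3  # assumed known array length
    df.select("key", *[F.col("value").getItem(i).alias("value" + str(i)) for i in range(n)]).show()

getItem(i) is equivalent to the value[i] indexing used earlier; the aliases just give the new columns nicer names than value[0], value[1], ...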
    

Solution 2:

I'd like to add the case of sized lists (arrays) to pault's answer.

If the column contains medium-sized (or even large) arrays, it is still possible to split it into columns.

from pyspark.sql.types import *          # Needed to define DataFrame Schema.
from pyspark.sql.functions import expr   

# Define schema to create DataFrame with an array typed column.
mySchema = StructType([StructField("V1", StringType(), True),
                       StructField("V2", ArrayType(IntegerType(),True))])

df = spark.createDataFrame([['A', [1, 2, 3, 4, 5, 6, 7]], 
                            ['B', [8, 7, 6, 5, 4, 3, 2]]], schema= mySchema)

# Split the list into columns using 'expr()' in a list comprehension.
arr_size = 7
df = df.select(['V1', 'V2']+[expr('V2[' + str(x) + ']') for x in range(0, arr_size)])

# It is possible to define new column names.
new_colnames = ['V1', 'V2'] + ['val_' + str(i) for i in range(0, arr_size)] 
df = df.toDF(*new_colnames)

The result is:

df.show(truncate= False)

+---+---------------------+-----+-----+-----+-----+-----+-----+-----+
|V1 |V2                   |val_0|val_1|val_2|val_3|val_4|val_5|val_6|
+---+---------------------+-----+-----+-----+-----+-----+-----+-----+
|A  |[1, 2, 3, 4, 5, 6, 7]|1    |2    |3    |4    |5    |6    |7    |
|B  |[8, 7, 6, 5, 4, 3, 2]|8    |7    |6    |5    |4    |3    |2    |
+---+---------------------+-----+-----+-----+-----+-----+-----+-----+
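
Here arr_size = 7 is hard-coded and has to match the data. If the array length is not fixed up front, one way to derive it before building the select (a sketch, not part of the original answer; note that it triggers a Spark job) is to take the maximum array size in the column:

    from pyspark.sql.functions import size, max as max_

    # Longest array in the column; with the default (non-ANSI) settings,
    # out-of-range indexes for shorter arrays simply come back as null.
    arr_size = df.select(max_(size("V2"))).first()[0]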

Solution 3:

@jordi Aceiton, thanks for the solution. I tried to make it more concise: the renaming loop is removed and the new column names are assigned with alias() while the columns are created, and df.columns is used to fetch the existing column names rather than listing them manually.

    from pyspark.sql import Row
    from pyspark.sql.functions import col

    df = spark.createDataFrame([Row(index=1, finalArray=[1.1, 2.3, 7.5], c=4),
                                Row(index=2, finalArray=[9.6, 4.1, 5.4], c=4)])
    # Collect all the existing column names as a list
    dlist = df.columns
    # Append the new element columns to the dataframe
    df.select(dlist + [(col("finalArray")[x]).alias("Value" + str(x + 1)) for x in range(0, 3)]).show()

Output:

     +---+---------------+-----+------+------+------+
     |  c|     finalArray|index|Value1|Value2|Value3|
     +---+---------------+-----+------+------+------+
     |  4|[1.1, 2.3, 7.5]|    1|   1.1|   2.3|   7.5|
     |  4|[9.6, 4.1, 5.4]|    2|   9.6|   4.1|   5.4|
     +---+---------------+-----+------+------+------+
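
Not part of the original answer, but since the same pattern is often needed for more than one array column, it can be wrapped in a small helper; split_array_col and the "Value" prefix are hypothetical names:

    from pyspark.sql.functions import col

    def split_array_col(df, array_col, n, prefix="Value"):
        # Append n element columns taken from array_col (hypothetical helper).
        return df.select(df.columns + [col(array_col)[i].alias(prefix + str(i + 1)) for i in range(n)])

    # Equivalent to the select above:
    split_array_col(df, "finalArray", 3).show()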

Solution 4:

I needed to unlist a 712-dimensional array into columns in order to write it to CSV. I first used @MaFF's solution for my problem, but that seemed to cause a lot of errors and additional computation time. I am not sure what was causing it, but the different method I used reduced the computation time considerably (22 minutes compared to more than 4 hours)!

@MaFF's method:

length = len(dataset.head()["list_col"])
dataset = dataset.select(dataset.columns + [dataset["list_col"][k] for k in range(length)])

What I used:

dataset = dataset.rdd.map(lambda x: (*x, *x["list_col"])).toDF()

If someone has any ideas about what was causing this difference in computation time, please let me know! I suspect that in my case the bottleneck was calling head() to get the list length (which I would like to be adaptive), compounded by the fact that (i) my data pipeline was quite long and exhaustive, and (ii) I had to unlist multiple columns. Furthermore, caching the entire dataset was not an option.
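
One caveat with the rdd.map(...).toDF() approach, added here as a side note: called without arguments, toDF() auto-generates names like _1, _2, ... for all of the tuple fields, including the original columns. A minimal sketch for restoring readable names, with list_col_<i> as placeholder names:

    # Convert as before; all fields of the tuple come back named _1, _2, ...
    unlisted = dataset.rdd.map(lambda x: (*x, *x["list_col"])).toDF()

    # Restore the original names and label the new element columns,
    # deriving the count from the converted dataframe rather than from head().
    n_new = len(unlisted.columns) - len(dataset.columns)
    names = dataset.columns + ["list_col_" + str(i) for i in range(n_new)]
    dataset = unlisted.toDF(*names)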