Split Spark Dataframe string column into multiple columns
pyspark.sql.functions.split()
is the right approach here - you simply need to flatten the nested ArrayType column into multiple top-level columns. In this case, where each array only contains 2 items, it's very easy. You simply use Column.getItem()
to retrieve each part of the array as a column itself:
split_col = pyspark.sql.functions.split(df['my_str_col'], '-')
df = df.withColumn('NAME1', split_col.getItem(0))
df = df.withColumn('NAME2', split_col.getItem(1))
The result will be:
col1 | my_str_col | NAME1 | NAME2
-----+------------+-------+------
18 | 856-yygrm | 856 | yygrm
201 | 777-psgdg | 777 | psgdg
I am not sure how I would solve this in a general case where the nested arrays were not the same size from Row to Row.
Here's a solution to the general case that doesn't involve needing to know the length of the array ahead of time, using collect
, or using udf
s. Unfortunately this only works for spark
version 2.1 and above, because it requires the posexplode
function.
Suppose you had the following DataFrame:
df = spark.createDataFrame(
[
[1, 'A, B, C, D'],
[2, 'E, F, G'],
[3, 'H, I'],
[4, 'J']
]
, ["num", "letters"]
)
df.show()
#+---+----------+
#|num| letters|
#+---+----------+
#| 1|A, B, C, D|
#| 2| E, F, G|
#| 3| H, I|
#| 4| J|
#+---+----------+
Split the letters
column and then use posexplode
to explode the resultant array along with the position in the array. Next use pyspark.sql.functions.expr
to grab the element at index pos
in this array.
import pyspark.sql.functions as f
df.select(
"num",
f.split("letters", ", ").alias("letters"),
f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)\
.show()
#+---+------------+---+---+
#|num| letters|pos|val|
#+---+------------+---+---+
#| 1|[A, B, C, D]| 0| A|
#| 1|[A, B, C, D]| 1| B|
#| 1|[A, B, C, D]| 2| C|
#| 1|[A, B, C, D]| 3| D|
#| 2| [E, F, G]| 0| E|
#| 2| [E, F, G]| 1| F|
#| 2| [E, F, G]| 2| G|
#| 3| [H, I]| 0| H|
#| 3| [H, I]| 1| I|
#| 4| [J]| 0| J|
#+---+------------+---+---+
Now we create two new columns from this result. First one is the name of our new column, which will be a concatenation of letter
and the index in the array. The second column will be the value at the corresponding index in the array. We get the latter by exploiting the functionality of pyspark.sql.functions.expr
which allows us use column values as parameters.
df.select(
"num",
f.split("letters", ", ").alias("letters"),
f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)\
.drop("val")\
.select(
"num",
f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
f.expr("letters[pos]").alias("val")
)\
.show()
#+---+-------+---+
#|num| name|val|
#+---+-------+---+
#| 1|letter0| A|
#| 1|letter1| B|
#| 1|letter2| C|
#| 1|letter3| D|
#| 2|letter0| E|
#| 2|letter1| F|
#| 2|letter2| G|
#| 3|letter0| H|
#| 3|letter1| I|
#| 4|letter0| J|
#+---+-------+---+
Now we can just groupBy
the num
and pivot
the DataFrame. Putting that all together, we get:
df.select(
"num",
f.split("letters", ", ").alias("letters"),
f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)\
.drop("val")\
.select(
"num",
f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
f.expr("letters[pos]").alias("val")
)\
.groupBy("num").pivot("name").agg(f.first("val"))\
.show()
#+---+-------+-------+-------+-------+
#|num|letter0|letter1|letter2|letter3|
#+---+-------+-------+-------+-------+
#| 1| A| B| C| D|
#| 3| H| I| null| null|
#| 2| E| F| G| null|
#| 4| J| null| null| null|
#+---+-------+-------+-------+-------+
Here's another approach, in case you want split a string with a delimiter.
import pyspark.sql.functions as f
df = spark.createDataFrame([("1:a:2001",),("2:b:2002",),("3:c:2003",)],["value"])
df.show()
+--------+
| value|
+--------+
|1:a:2001|
|2:b:2002|
|3:c:2003|
+--------+
df_split = df.select(f.split(df.value,":")).rdd.flatMap(
lambda x: x).toDF(schema=["col1","col2","col3"])
df_split.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
| 1| a|2001|
| 2| b|2002|
| 3| c|2003|
+----+----+----+
I don't think this transition back and forth to RDDs is going to slow you down... Also don't worry about last schema specification: it's optional, you can avoid it generalizing the solution to data with unknown column size.