Dividing complex rows of a DataFrame into simple rows in PySpark
I have this code:
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row
sc = SparkContext()
sqlContext = SQLContext(sc)
documents = sqlContext.createDataFrame([
Row(id=1, title=[Row(value=u'cars', max_dist=1000)]),
Row(id=2, title=[Row(value=u'horse bus',max_dist=50), Row(value=u'normal bus',max_dist=100)]),
Row(id=3, title=[Row(value=u'Airplane', max_dist=5000)]),
Row(id=4, title=[Row(value=u'Bicycles', max_dist=20),Row(value=u'Motorbikes', max_dist=80)]),
Row(id=5, title=[Row(value=u'Trams', max_dist=15)])])
documents.show(truncate=False)
#+---+----------------------------------+
#|id |title                             |
#+---+----------------------------------+
#|1  |[[1000,cars]]                     |
#|2  |[[50,horse bus], [100,normal bus]]|
#|3  |[[5000,Airplane]]                 |
#|4  |[[20,Bicycles], [80,Motorbikes]]  |
#|5  |[[15,Trams]]                      |
#+---+----------------------------------+
I need to split all compound rows (e.g. 2 & 4) into multiple rows while retaining the 'id', to get a result like this:
#+---+----------------------------------+
#|id |title                             |
#+---+----------------------------------+
#|1  |[1000,cars]                       |
#|2  |[50,horse bus]                    |
#|2  |[100,normal bus]                  |
#|3  |[5000,Airplane]                   |
#|4  |[20,Bicycles]                     |
#|4  |[80,Motorbikes]                   |
#|5  |[15,Trams]                        |
#+---+----------------------------------+
Solution 1:
Just explode it:
from pyspark.sql.functions import explode
documents.withColumn("title", explode("title")).show()
## +---+----------------+
## | id|           title|
## +---+----------------+
## |  1|     [1000,cars]|
## |  2|  [50,horse bus]|
## |  2|[100,normal bus]|
## |  3| [5000,Airplane]|
## |  4|   [20,Bicycles]|
## |  4| [80,Motorbikes]|
## |  5|      [15,Trams]|
## +---+----------------+
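To make explicit what explode does to each row, here is a rough plain-Python model of the transformation. It is only a sketch of the semantics, not how Spark implements it: `explode_column` is a hypothetical helper, rows are modeled as dicts, and the row with id 6 is made up to show one edge case. Rows whose array is empty or null produce no output with explode (newer Spark releases also offer explode_outer if those rows should be kept).

```python
def explode_column(rows, column):
    """Plain-Python model of explode: one output row per array element."""
    out = []
    for row in rows:
        # A missing (None) or empty array yields no output rows at all.
        for element in row[column] or []:
            new_row = dict(row)        # copy the other columns unchanged
            new_row[column] = element  # replace the array with one element
            out.append(new_row)
    return out


rows = [
    {"id": 1, "title": [(1000, "cars")]},
    {"id": 2, "title": [(50, "horse bus"), (100, "normal bus")]},
    {"id": 6, "title": []},  # hypothetical extra row, not in the question's data
]

exploded = explode_column(rows, "title")
for row in exploded:
    print(row["id"], row["title"])
```

The input rows are left untouched, and the id is simply repeated for every element that came from the same array.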
Solution 2:
OK, here is what I've come up with. Unfortunately, I had to leave the world of Row objects and enter the world of list objects because I couldn't find a way to append to a Row object.
That means this method is a bit messy. If you can find a way to add a new column to a Row object, then this is NOT the way to go.
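def add_id(row):
    # row[0] is the id; row[1] is the list of nested title Rows.
    it_list = []
    for i in range(0, len(row[1])):
        sm_list = []
        # Copy the fields of the nested Row. Keyword fields are stored
        # sorted by name here, so the order is (max_dist, value).
        for j in row[1][i]:
            sm_list.append(j)
        sm_list.append(row[0])  # the id ends up last, at index 2
        it_list.append(sm_list)
    return it_list

# Go through .rdd: a DataFrame has no flatMap of its own from Spark 2.0 on.
with_id = documents.rdd.flatMap(lambda x: add_id(x))
df = with_id.map(lambda x: Row(id=x[2], title=Row(max_dist=x[0], value=x[1]))).toDF()
When I run df.show(), I get:
+---+----------------+
| id|           title|
+---+----------------+
|  1|     [1000,cars]|
|  2|  [50,horse bus]|
|  2|[100,normal bus]|
|  3| [5000,Airplane]|
|  4|   [20,Bicycles]|
|  4| [80,Motorbikes]|
|  5|      [15,Trams]|
+---+----------------+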
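The per-row function may not need to append anything to the nested objects at all: pairing the parent id with each nested element, and keeping the element whole, gives the same shape. Below is a minimal sketch of that idea. It runs without Spark by assuming namedtuples (`Doc`, `Title`) as stand-ins for Row and `itertools.chain.from_iterable` as a stand-in for flatMap; `split_titles` is a hypothetical name, and I have not run this against a real cluster.

```python
from collections import namedtuple
from itertools import chain

# Stand-ins for pyspark.sql.Row so the sketch runs without Spark (assumption).
Title = namedtuple("Title", ["max_dist", "value"])
Doc = namedtuple("Doc", ["id", "title"])


def split_titles(doc):
    """Return one Doc per nested title, each carrying the parent id."""
    return [Doc(id=doc.id, title=t) for t in doc.title]


docs = [
    Doc(1, [Title(1000, "cars")]),
    Doc(2, [Title(50, "horse bus"), Title(100, "normal bus")]),
]

# chain.from_iterable plays the role of flatMap here.
flat = list(chain.from_iterable(split_titles(d) for d in docs))
for d in flat:
    print(d.id, d.title.max_dist, d.title.value)
```

Because each nested element is reused as is, its fields are never unpacked by position, so there is no index bookkeeping to get wrong.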