How to split a dataframe into dataframes with same column values?

You can collect the unique state values and simply map over the resulting array:

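// Needed for the $"..." column syntax (assumes the session is named spark)
import spark.implicits._

val states = df.select("State").distinct.collect.flatMap(_.toSeq)
val byStateArray = states.map(state => df.where($"State" <=> state))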

or into a map:

val byStateMap = states
    .map(state => (state -> df.where($"State" <=> state)))
    .toMap

The same thing in Python:

from itertools import chain
from pyspark.sql.functions import col

states = list(chain.from_iterable(df.select("state").distinct().collect()))

# eqNullSafe requires PySpark 2.3 or later.
# In 2.2 and earlier, col("state") == state gives the same
# outcome except that NULLs are ignored. If NULLs matter, use
# (lit(state).isNull() & col("state").isNull()) | (col("state") == state)
df_by_state = {
    state: df.where(col("state").eqNullSafe(state)) for state in states
}

The obvious problem here is that it requires a full data scan for each level (each distinct value), so it is an expensive operation. If you're looking for a way to just split the output, see also "How do I split an RDD into two or more RDDs?"

In particular, you can write the Dataset partitioned by the column of interest:

val path: String = ???
df.write.partitionBy("State").parquet(path)

and read back if needed:

// Rely on partition pruning
for { state <- states } yield spark.read.parquet(path).where($"State" === state)

// or explicitly read the partition
for { state <- states } yield spark.read.parquet(s"$path/State=$state")
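The explicit-partition read relies on the directory layout that partitionBy produces, one `column=value` directory per level. A minimal Python sketch of that path construction follows. The helper `partition_paths` and the `/data/by_state` location are hypothetical, and the sketch assumes the values contain no characters that Spark escapes in directory names. NULL values are written under `__HIVE_DEFAULT_PARTITION__`.

```python
from typing import Dict, Iterable, Optional

# Directory name Spark uses for NULL partition values.
NULL_PARTITION = "__HIVE_DEFAULT_PARTITION__"


def partition_paths(base: str, column: str,
                    values: Iterable[Optional[str]]) -> Dict[Optional[str], str]:
    """Map each partition value to the directory written by partitionBy.

    Hypothetical helper: assumes values contain no characters that
    Spark escapes in directory names (such as '/', '=' or '%').
    """
    root = base.rstrip("/")
    return {
        v: f"{root}/{column}={NULL_PARTITION if v is None else v}"
        for v in values
    }


paths = partition_paths("/data/by_state/", "State", ["FL", "MN", None])
for state, p in paths.items():
    print(state, "->", p)

# With a live session (not executed here), each split could be loaded as:
#   spark.read.option("basePath", "/data/by_state").parquet(paths["FL"])
```

Note that reading a partition directory directly drops the partition column from the result unless the `basePath` option points at the root of the partitioned data, as in the commented line.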

Depending on the size of the data, the number of levels of the splitting, and the storage and persistence level of the input, it might be faster or slower than multiple filters.


This is very simple in Spark 2 or later if you register the dataframe as a temporary view.

df1.createOrReplaceTempView("df1")

Now you can run the queries:

val df2 = spark.sql("select * from df1 where state = 'FL'")
val df3 = spark.sql("select * from df1 where state = 'MN'")
val df4 = spark.sql("select * from df1 where state = 'AL'")

Now you have df2, df3, and df4. If you want their rows collected to the driver as a local array, you can use:

df2.collect()
df3.collect()

or even the map/filter functions. Please refer to https://spark.apache.org/docs/latest/sql-programming-guide.html#datasets-and-dataframes
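The hardcoded queries above can also be generated for any list of states. A minimal Python sketch follows, assuming the view is registered as `df1` and the values are plain strings. The helper `state_queries` is hypothetical, not part of Spark, and it refuses values that would need escaping rather than trying to quote them.

```python
from typing import Dict, Iterable


def state_queries(view: str, column: str, values: Iterable[str]) -> Dict[str, str]:
    """Build one SELECT per value (hypothetical helper, not a Spark API)."""
    queries = {}
    for v in values:
        # Refuse values that would need escaping inside a SQL string literal.
        if "'" in v or "\\" in v:
            raise ValueError(f"unsupported value: {v!r}")
        queries[v] = f"select * from {view} where {column} = '{v}'"
    return queries


queries = state_queries("df1", "state", ["FL", "MN", "AL"])
print(queries["FL"])

# With a live session (not executed here):
#   by_state = {s: spark.sql(q) for s, q in queries.items()}
```

Like the filter-based approach, this still issues one query per state, so each resulting dataframe scans the view separately.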
