Skewed dataset join in Spark?

I am joining two big datasets using Spark RDDs. One dataset is heavily skewed, so a few executor tasks take a long time to finish and hold up the whole job. How can I solve this?

Solution 1:

There is a pretty good article on how this can be done, "Fighting the Skew In Spark".

Short version:

  • Add a random element (a salt) to each record of the large RDD and build a new join key from the original key plus the salt
  • Replicate each record of the small RDD once per possible salt value using explode/flatMap, and build the same kind of composite join key
  • Join the RDDs on the new join key, which is now distributed more evenly thanks to the random salt
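The three steps can be simulated in plain Python to see why they help. This is not Spark code: `key_hash` is a simplified, hypothetical stand-in for a hash partitioner, and the key counts and `N_SALTS = 8` are made-up numbers. The sketch hash-partitions the large side's keys before and after salting and compares the largest partition.

```python
import random
from collections import Counter

NUM_PARTITIONS = 8
N_SALTS = 8  # hypothetical salt range; in practice sized to how hot the skewed key is


def key_hash(key):
    # Simplified, deterministic stand-in for a hash function (not Spark's actual one):
    # plain int keys hash to themselves, (key, salt) pairs combine as 31 * key + salt.
    if isinstance(key, tuple):
        k, s = key
        return 31 * k + s
    return key


def partition_loads(keys):
    # Rows per partition when records are hash-partitioned on `keys`.
    loads = Counter(key_hash(k) % NUM_PARTITIONS for k in keys)
    return [loads.get(p, 0) for p in range(NUM_PARTITIONS)]


random.seed(0)
# Hypothetical large RDD keys: key 1 is hot (10,000 rows), keys 2..100 have 10 rows each.
large_keys = [1] * 10_000 + [k for k in range(2, 101) for _ in range(10)]

# Without salting, every row with key 1 is shuffled to the same partition.
before = partition_loads(large_keys)

# Large side: append a random salt, so the join key becomes (key, salt).
salted_large = [(k, random.randrange(N_SALTS)) for k in large_keys]
after = partition_loads(salted_large)

# Small side: emit one copy per salt value (the explode/flatMap step), so every
# (key, salt) produced on the large side still finds its match.
small_keys = list(range(1, 101))
salted_small = [(k, s) for k in small_keys for s in range(N_SALTS)]

print("largest partition before salting:", max(before))
print("largest partition after salting: ", max(after))
```

Salting only moves the hot key's rows apart; the price is that the small side grows `N_SALTS`-fold, which is why the salt range should stay as small as the skew allows.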

Solution 2:

Say you have to join two tables A and B on A.id = B.id. Let's assume that table A has skew on id = 1.

i.e. select A.* from A join B on A.id = B.id

There are two basic approaches to solve the skew join issue:

Approach 1:

Break your query/dataset into two parts: one containing only the skewed key and the other containing the non-skewed data. In the above example, the query becomes:

 1. select A.* from A join B on A.id = B.id where A.id <> 1;
 2. select A.* from A join B on A.id = B.id where A.id = 1 and B.id = 1;

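The first query will not have any skew, so all the tasks of the ResultStage will finish at roughly the same time.

If we assume that B has only a few rows with B.id = 1, they will fit into memory, so the second query can run as a broadcast join. Spark chooses this automatically when the filtered side is estimated to be smaller than spark.sql.autoBroadcastJoinThreshold, and you can also force it with a broadcast hint. In Hive this is called a map-side join.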


The partial results of the two queries can then be merged to get the final results.
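A minimal plain-Python sketch of Approach 1, assuming the hot key (id = 1) is known in advance. `hash_join`, `SKEWED_ID` and the sample tables are hypothetical; an in-memory hash join stands in for both Spark joins, with the second one playing the role of the broadcast/map-side join. The union of the two partial results should equal the plain join.

```python
from collections import defaultdict


def hash_join(left, right, key):
    # In-memory hash join: build a lookup table from `right`, then stream `left` through it.
    # When `right` is small, this is what a broadcast / map-side join does on every executor.
    table = defaultdict(list)
    for row in right:
        table[row[key]].append(row)
    return [(a_row, b_row) for a_row in left for b_row in table.get(a_row[key], [])]


SKEWED_ID = 1  # hypothetical: the hot key, assumed to be known in advance

# Hypothetical tables: A has 1,000 rows with id=1 plus 50 spread over ids 2..11; B has one row per id.
A = [{"id": 1, "v": i} for i in range(1000)] + [{"id": i % 10 + 2, "v": i} for i in range(50)]
B = [{"id": i, "name": f"b{i}"} for i in range(1, 12)]

# Query 1: A.id <> 1. No skewed key is left, so an ordinary shuffle join balances well.
part1 = hash_join([a for a in A if a["id"] != SKEWED_ID], B, "id")

# Query 2: A.id = 1 and B.id = 1. The filtered B is tiny, so it can be broadcast.
b_hot = [b for b in B if b["id"] == SKEWED_ID]
part2 = hash_join([a for a in A if a["id"] == SKEWED_ID], b_hot, "id")

# Merge the partial results (UNION ALL in SQL).
result = part1 + part2
```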

Approach 2:

As also mentioned by LiMuBei above, the second approach randomizes the join key by appending an extra column. Steps:

  1. Add a column to the larger table (A), say skewLeft, and populate it with random integers from 0 to N-1 for all rows.

  2. Add a column to the smaller table (B), say skewRight, and replicate the smaller table N times, so that skewRight takes each value from 0 to N-1 once per original row. You can use the explode SQL/Dataset operator for this.

After steps 1 and 2, join the two datasets/tables with the join condition updated to:

                A.id = B.id && A.skewLeft = B.skewRight
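The steps of Approach 2 can be checked in plain Python as well. The sketch below uses hypothetical data, with dict copies standing in for explode: it salts A, replicates B N times, joins on both conditions, and confirms the result matches the unsalted join. Each A row matches exactly one of the N copies of its B row, so nothing is lost or duplicated.

```python
import random
from collections import Counter

N = 4  # number of salt values, N in the steps above
random.seed(42)

# Hypothetical tables: A is skewed on id=1, B is the smaller table.
A = [{"id": 1, "v": i} for i in range(200)] + [{"id": 2, "v": i} for i in range(5)]
B = [{"id": 1, "w": "x"}, {"id": 2, "w": "y"}, {"id": 3, "w": "z"}]

# Step 1: skewLeft = random integer in 0..N-1 for every row of the larger table.
A_salted = [dict(a, skewLeft=random.randrange(N)) for a in A]

# Step 2: replicate the smaller table N times; copy s gets skewRight = s (what explode produces).
B_salted = [dict(b, skewRight=s) for b in B for s in range(N)]

# Join condition: A.id = B.id && A.skewLeft = B.skewRight
salted = [(a, b) for a in A_salted for b in B_salted
          if a["id"] == b["id"] and a["skewLeft"] == b["skewRight"]]

# Reference: the plain, unsalted join.
plain = [(a, b) for a in A for b in B if a["id"] == b["id"]]


def project(pairs):
    # Ignore the helper columns and compare the joined rows as a multiset.
    return Counter((a["id"], a["v"], b["w"]) for a, b in pairs)


# How the 200 hot rows are spread across the N salt values (each value becomes its own join task).
hot_spread = Counter(a["skewLeft"] for a in A_salted if a["id"] == 1)
```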


Solution 3:

Depending on the particular kind of skew you're experiencing, there may be different ways to solve it. The basic idea is:

  • Modify your join column, or create a new join column, that is not skewed but which still retains adequate information to do the join
  • Do the join on that non-skewed column -- resulting partitions will not be skewed
  • Following the join, you can update the join column back to your preferred format, or drop it if you created a new column

The "Fighting the Skew In Spark" article referenced in LiMuBei's answer describes a good technique if the skewed data participates in the join. In my case, the skew was caused by a very large number of null values in the join column. The nulls were not participating in the join, but since Spark partitions on the join column, the post-join partitions were very skewed: one gigantic partition held all of the nulls.

I solved it by adding a new column which changed all null values to a well-distributed temporary value, such as "NULL_VALUE_X", where X is replaced by random numbers between say 1 and 10,000, e.g. (in Java):

import static org.apache.spark.sql.functions.when;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.functions;

// Before the join, create a join column with well-distributed temporary values for null swids.  This column
// will be dropped after the join.  We need to do this so the post-join partitions will be well-distributed,
// and not have a giant partition with all null swids.
String swidWithDistributedNulls = "swid_with_distributed_nulls";
int numNullValues = 10000; // Just use a number that will always be bigger than the number of partitions
Column swidWithDistributedNullsCol =
    when(csDataset.col(CS_COL_SWID).isNull(), functions.concat(
            functions.lit("NULL_VALUE_"),
            functions.rand().multiply(numNullValues).cast("int").plus(1).cast("string")))
        .otherwise(csDataset.col(CS_COL_SWID));
csDataset = csDataset.withColumn(swidWithDistributedNulls, swidWithDistributedNullsCol);

Then join on this new column, and after the join drop the temporary column again with Dataset.drop(swidWithDistributedNulls).

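The null-redistribution trick can be sketched in plain Python, assuming an inner join and a simplified, hypothetical partitioner (`partition_of`) that sends every null to the same partition. The names mirror the Java snippet, but the data is made up. The sketch shows the largest partition shrinking, the join result staying the same because the NULL_VALUE_X keys match nothing on the other side (this assumes no real swid looks like NULL_VALUE_X), and the temporary column being dropped afterwards.

```python
import hashlib
import random
from collections import Counter

NUM_PARTITIONS = 8
NUM_NULL_VALUES = 10_000  # plays the role of numNullValues in the Java code

random.seed(7)
# Hypothetical data: 5,000 left rows have a null swid; 800 carry one of 40 real swids.
left = [{"swid": None, "n": i} for i in range(5000)] + [{"swid": f"s{i % 40}", "n": i} for i in range(800)]
right = [{"swid": f"s{i}", "info": i} for i in range(40)]


def partition_of(key):
    # Simplified stand-in for hash partitioning: every null lands in the same fixed partition.
    if key is None:
        return 0
    return int(hashlib.sha256(key.encode()).hexdigest(), 16) % NUM_PARTITIONS


# Temporary join column: nulls become NULL_VALUE_X with X random in 1..NUM_NULL_VALUES.
for row in left:
    swid = row["swid"]
    row["swid_with_distributed_nulls"] = (
        swid if swid is not None else f"NULL_VALUE_{random.randint(1, NUM_NULL_VALUES)}")

before = Counter(partition_of(row["swid"]) for row in left)
after = Counter(partition_of(row["swid_with_distributed_nulls"]) for row in left)

# Inner join on the temporary column; NULL_VALUE_X keys match nothing, just as the nulls didn't.
right_by_swid = {r["swid"]: r for r in right}
joined = [(row, right_by_swid[row["swid_with_distributed_nulls"]])
          for row in left if row["swid_with_distributed_nulls"] in right_by_swid]

# After the join, drop the temporary column again.
result = [({k: v for k, v in row.items() if k != "swid_with_distributed_nulls"}, r)
          for row, r in joined]
```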

Solution 4:

Taking the "Fighting the Skew In Spark" article as a reference, below is code for fighting the skew in Spark using the PySpark DataFrame API.

Creating the two dataframes:

from math import exp
from datetime import datetime

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

def count_elements(splitIndex, iterator):
    # Helper to inspect partition sizes, e.g. rdd.mapPartitionsWithIndex(count_elements).collect()
    n = sum(1 for _ in iterator)
    yield (splitIndex, n)

def get_part_index(splitIndex, iterator):
    # Tag every element with the index of the partition it lives in
    for it in iterator:
        yield (splitIndex, it)

num_parts = 18
# create the large skewed rdd: partition x holds int(exp(x)) elements, all keyed by x
skewed_large_rdd = sc.parallelize(range(0, num_parts), num_parts).flatMap(lambda x: range(0, int(exp(x))))
skewed_large_rdd = skewed_large_rdd.mapPartitionsWithIndex(lambda ind, x: get_part_index(ind, x))

skewed_large_df = spark.createDataFrame(skewed_large_rdd, ['x', 'y'])

small_rdd = sc.parallelize(range(0, num_parts), num_parts).map(lambda x: (x, x))

small_df = spark.createDataFrame(small_rdd, ['a', 'b'])
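To see how skewed this data is, the row counts can be worked out without running Spark. The snippet below is plain-Python arithmetic that mirrors the flatMap in the previous code block; `rows_per_key`, `hot_key` and `hot_share` are illustrative names, not part of the original code.

```python
from math import exp

num_parts = 18

# In skewed_large_df, key x has int(exp(x)) rows, and all of them start in partition x.
rows_per_key = {x: int(exp(x)) for x in range(num_parts)}
total_rows = sum(rows_per_key.values())
hot_key = max(rows_per_key, key=rows_per_key.get)
hot_share = rows_per_key[hot_key] / total_rows

print(f"{total_rows:,} rows in total; key {hot_key} alone holds {hot_share:.1%} of them")
```

With num_parts = 18 this comes to about 38.2 million rows, roughly 63% of them under key 17, so without salting one task does most of the join work.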

Divide the large df into 100 salt bins and replicate the small df 100 times:

from pyspark.sql import functions as F

salt_bins = 100

# Large df: give every row a random salt in 0..salt_bins-1
skewed_transformed_df = skewed_large_df.withColumn('salt', (F.rand() * salt_bins).cast('int')).cache()

# Small df: replicate every row once per salt value
small_transformed_df = small_df.withColumn('replicate', F.array([F.lit(i) for i in range(salt_bins)]))

small_transformed_df = small_transformed_df.select('*', F.explode('replicate').alias('salt')).drop('replicate').cache()

Finally, the salted join that avoids the skew:

t0 = datetime.now()
result2 = skewed_transformed_df.join(small_transformed_df, (skewed_transformed_df['x'] == small_transformed_df['a']) & (skewed_transformed_df['salt'] == small_transformed_df['salt']))
result2.count()  # the join is lazy; count() forces it to run so the timing is meaningful
print("The salted join takes %s" % (datetime.now() - t0))
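Choosing salt_bins is a trade-off: more bins shrink the largest (key, salt) group but replicate the small side more. The back-of-the-envelope arithmetic below is plain Python (no Spark); the `salting_cost` helper is hypothetical, and the "largest group" figure is only approximate, since rand() spreads rows roughly, not exactly, evenly.

```python
from math import ceil, exp


def salting_cost(rows_per_key, small_rows, salt_bins):
    # Approximate rows in the largest (key, salt) group, assuming the hot key is split
    # evenly across salt_bins, and the row count of the small side after replication.
    hottest = max(rows_per_key.values())
    return ceil(hottest / salt_bins), small_rows * salt_bins


rows_per_key = {x: int(exp(x)) for x in range(18)}  # same distribution as skewed_large_df
small_rows = 18                                      # small_df has one row per key

for bins in (1, 10, 100, 1000):
    largest_group, replicated_small = salting_cost(rows_per_key, small_rows, bins)
    print(f"salt_bins={bins:>4}: largest group ~{largest_group:,} rows, "
          f"small side {replicated_small:,} rows")
```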

Solution 5:

Apache DataFu has two methods for doing skewed joins that implement some of the suggestions in the previous answers.

The joinSkewed method does salting (adding a random number column to split the skewed values).

The broadcastJoinSkewed method is for when you can divide the dataframe into skewed and regular parts, as described in Approach 1 from the answer by moriarty007.

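These DataFu methods are useful for projects still on Spark 2.x. If you are already on Spark 3, skewed joins can be handled by the built-in Adaptive Query Execution skew-join optimization (spark.sql.adaptive.enabled together with spark.sql.adaptive.skewJoin.enabled).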

Full disclosure - I am a member of Apache DataFu.