Spark SQL broadcast hash join

I'm trying to perform a broadcast hash join on dataframes using SparkSQL as documented here: https://docs.cloud.databricks.com/docs/latest/databricks_guide/06%20Spark%20SQL%20%26%20DataFrames/05%20BroadcastHashJoin%20-%20scala.html

In that example, the (small) DataFrame is persisted via saveAsTable and then joined via Spark SQL (i.e. via sqlContext.sql("...")).

The problem I have is that I need to use the Spark SQL API to construct my query (I am left joining ~50 tables with an ID list, and don't want to write the SQL by hand).

How do I tell Spark to use the broadcast hash join via the API? The issue is that if I load the ID list (from the table persisted via `saveAsTable`) into a `DataFrame` to use in the join, it isn't clear to me whether Spark can apply the broadcast hash join.

You can explicitly mark the DataFrame as small enough for broadcasting using the broadcast function:

Python:

from pyspark.sql.functions import broadcast

small_df = ...
large_df = ...

large_df.join(broadcast(small_df), ["foo"])

or the broadcast hint (Spark >= 2.2):

large_df.join(small_df.hint("broadcast"), ["foo"])
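For the case in the question, where many tables are joined through the API instead of hand-written SQL, the hint can be applied in a loop. The following is a minimal sketch, not a definitive implementation: `left_join_all`, `base_df`, `right_dfs` and `key` are hypothetical names, and it assumes every right-hand DataFrame is small enough to broadcast and shares the same join column. It relies only on the `DataFrame.join` and `DataFrame.hint` methods shown above.

```python
from functools import reduce


def left_join_all(base_df, right_dfs, key, hint_name="broadcast"):
    """Left-join every DataFrame in right_dfs onto base_df on column `key`.

    Each right-hand DataFrame is tagged with a join hint before joining,
    marking it as a broadcast candidate (DataFrame.hint needs Spark >= 2.2).
    """
    def join_one(acc, right):
        # acc is the running result; right is the next table to attach.
        return acc.join(right.hint(hint_name), on=[key], how="left")

    # Fold the list of tables into one chained join, starting from base_df.
    return reduce(join_one, right_dfs, base_df)
```

Usage would look like `left_join_all(ids_df, [table_1, table_2], "id")`, where the DataFrame names are placeholders. Calling `explain()` on the result is a way to check whether the physical plan actually contains a broadcast hash join for each step.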

Scala:

import org.apache.spark.sql.functions.broadcast

val smallDF: DataFrame = ???
val largeDF: DataFrame = ???

largeDF.join(broadcast(smallDF), Seq("foo"))

or the broadcast hint (Spark >= 2.2):

largeDF.join(smallDF.hint("broadcast"), Seq("foo"))

SQL:

You can use hints (Spark >= 2.2):

SELECT /*+ MAPJOIN(small) */ * 
FROM large JOIN small
ON large.foo = small.foo

or

SELECT /*+ BROADCASTJOIN(small) */ *
FROM large JOIN small
ON large.foo = small.foo

or

SELECT /*+ BROADCAST(small) */ * 
FROM large JOIN small
ON large.foo = small.foo

R (SparkR):

With hint (Spark >= 2.2):

join(large, hint(small, "broadcast"), large$foo == small$foo)

With broadcast (Spark >= 2.3):

join(large, broadcast(small), large$foo == small$foo)

Note:

Broadcast join is useful if one of the structures is relatively small. Otherwise it can be significantly more expensive than a full shuffle.


join_rdd = sqlContext.sql("""select * from people_in_india p
                             join states s
                             on p.state = s.name""")


join_rdd.toDebugString() / join_rdd.explain() shows which of the following strategies is used:

ShuffledHashJoin:

All the data for India will be shuffled into only 29 keys, one for each of the states. Problems: uneven sharding, and parallelism limited to 29 output partitions.
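The skew problem can be illustrated without Spark. The following is a toy simulation in plain Python, not Spark's actual partitioner: `partition_sizes` and the sample rows are made up for illustration, and `zlib.crc32` stands in for the hash function. It shows that hash-partitioning on the join key can never fill more partitions than there are distinct keys, and that the most common key dictates the size of the largest partition.

```python
import zlib
from collections import Counter


def partition_sizes(rows, key, num_partitions):
    """Count rows per shuffle partition when rows are hash-partitioned on `key`."""
    sizes = Counter()
    for row in rows:
        # crc32 is stable across runs, unlike Python's built-in hash() for str.
        bucket = zlib.crc32(row[key].encode("utf-8")) % num_partitions
        sizes[bucket] += 1
    return sizes


# Made-up sample: 10 people spread unevenly over 3 states.
people = (
    [{"name": f"p{i}", "state": "Uttar Pradesh"} for i in range(6)]
    + [{"name": f"q{i}", "state": "Kerala"} for i in range(3)]
    + [{"name": "r0", "state": "Goa"}]
)

sizes = partition_sizes(people, "state", num_partitions=200)
print(len(sizes), "non-empty partitions out of 200")
print("largest partition holds", max(sizes.values()), "rows")
```

Even with 200 partitions requested, at most three receive any rows here, and one of them holds at least the six rows of the most populous state.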

BroadcastHashJoin:

Broadcast the small RDD to all worker nodes. The parallelism of the large RDD is maintained, and no shuffle is required.
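The idea can be sketched in plain Python. This is a simplified, single-process illustration, not Spark's implementation: `broadcast_hash_join` and the sample data are hypothetical. A lookup table is built once from the small side, and every partition of the large side probes it where it already sits, so the partitioning of the large side is unchanged.

```python
from collections import defaultdict


def broadcast_hash_join(large_partitions, small_rows, large_key, small_key):
    """Inner-join each partition of the large side against a shared lookup table."""
    # "Broadcast": build one hash table from the small side, shared by every task.
    lookup = defaultdict(list)
    for small_row in small_rows:
        lookup[small_row[small_key]].append(small_row)

    joined = []
    for partition in large_partitions:
        # Each partition is probed independently; no row changes partition.
        joined.append([
            (row, match)
            for row in partition
            for match in lookup.get(row[large_key], [])
        ])
    return joined


states = [{"name": "Kerala", "capital": "Thiruvananthapuram"},
          {"name": "Goa", "capital": "Panaji"}]
people_partitions = [
    [{"person": "a", "state": "Kerala"}, {"person": "b", "state": "Goa"}],
    [{"person": "c", "state": "Kerala"}],
    [{"person": "d", "state": "Sikkim"}],  # no matching state row
]

result = broadcast_hash_join(people_partitions, states, "state", "name")
print([len(p) for p in result])  # → [2, 1, 0]
```

The output keeps three partitions, matching the input, which is the property that lets the large side retain its parallelism.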

PS: The image may be ugly, but it is informative.