How do I unit test PySpark programs?

Solution 1:

I'd recommend pytest (py.test) as well. It makes it easy to create reusable SparkContext test fixtures and use them to write concise test functions. You can also specialize fixtures (to create a StreamingContext, for example) and use one or more of them in your tests.

I wrote a blog post on Medium on this topic:

https://engblog.nextdoor.com/unit-testing-apache-spark-with-py-test-3b8970dc013b

Here is a snippet from the post:

import pytest

import wordcount  # the module under test (not shown here)

pytestmark = pytest.mark.usefixtures("spark_context")


def test_do_word_counts(spark_context):
    """Test word counting.

    Args:
        spark_context: test fixture SparkContext
    """
    test_input = [
        ' hello spark ',
        ' hello again spark spark'
    ]

    input_rdd = spark_context.parallelize(test_input, 1)
    results = wordcount.do_word_counts(input_rdd)

    expected_results = {'hello': 2, 'spark': 3, 'again': 1}
    assert results == expected_results
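
The wordcount module itself isn't shown in the snippet. As a rough sketch only, and assuming do_word_counts takes an RDD of text lines and returns a plain dict (which is what the assertion above implies), it could look something like the following. The body is a guess at its shape, not the code from the post:

```python
from operator import add


def do_word_counts(lines):
    """Count words in an RDD of text lines (hypothetical implementation).

    Args:
        lines: an RDD of strings
    Returns:
        dict mapping each word to its number of occurrences
    """
    return (lines
            .flatMap(lambda line: line.split())  # split() also drops extra spaces
            .map(lambda word: (word, 1))
            .reduceByKey(add)
            .collectAsMap())  # pairs collected to the driver as a dict
```

Returning a dict rather than a list of pairs keeps the test independent of the order in which Spark hands back the results.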

Solution 2:

Here's a solution with pytest if you're using Spark 2.x and SparkSession. I'm also loading a third-party package (spark-avro) through the spark.jars.packages setting.

import logging

import pytest
from pyspark.sql import SparkSession

def quiet_py4j():
    """Suppress spark logging for the test context."""
    logger = logging.getLogger('py4j')
    logger.setLevel(logging.WARN)


@pytest.fixture(scope="session")
def spark_session(request):
    """Fixture for creating a Spark session."""

    spark = (SparkSession
             .builder
             .master('local[2]')
             .config('spark.jars.packages', 'com.databricks:spark-avro_2.11:3.0.1')
             .appName('pytest-pyspark-local-testing')
             .enableHiveSupport()
             .getOrCreate())
    request.addfinalizer(spark.stop)

    quiet_py4j()
    return spark


def test_my_app(spark_session):
    ...

Note that with Python 3, I had to specify the interpreter through the PYSPARK_PYTHON environment variable:

import os
import sys

IS_PY2 = sys.version_info < (3,)

if not IS_PY2:
    os.environ['PYSPARK_PYTHON'] = 'python3'

Otherwise you get the error:

Exception: Python in worker has different version 2.7 than that in driver 3.5, PySpark cannot run with different minor versions.Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set.
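
A possible variation, offered as a sketch rather than something from the setup above: instead of hard-coding 'python3', point the workers at whichever interpreter is running the tests. This assumes the code runs before the SparkSession or SparkContext is created, for example at the top of conftest.py:

```python
import os
import sys

# Make Spark workers use the same interpreter as the driver (the one running
# the tests), so the driver and worker Python versions cannot drift apart.
os.environ['PYSPARK_PYTHON'] = sys.executable
```

This also covers virtualenvs, where 'python3' on the PATH may not be the interpreter the tests were started with.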

Solution 3:

Assuming you have pyspark installed, you can use the class below as a base class for your unittest test cases:

import unittest
import pyspark


class PySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
        cls.sc = pyspark.SparkContext(conf=conf)
        cls.spark = pyspark.SQLContext(cls.sc)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()

Example:

class SimpleTestCase(PySparkTestCase):

    def test_with_rdd(self):
        test_input = [
            ' hello spark ',
            ' hello again spark spark'
        ]

        input_rdd = self.sc.parallelize(test_input, 1)

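        from operator import add

        results = (input_rdd.flatMap(lambda x: x.split())
                   .map(lambda x: (x, 1))
                   .reduceByKey(add)
                   .collect())
        # The order of the pairs after reduceByKey is not guaranteed,
        # so compare them as a dict instead of a list.
        self.assertEqual(dict(results), {'hello': 2, 'spark': 3, 'again': 1})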

    def test_with_df(self):
        df = self.spark.createDataFrame(data=[[1, 'a'], [2, 'b']], 
                                        schema=['c1', 'c2'])
        self.assertEqual(df.count(), 2)

Note that this creates one context per class. To get a context per test, use setUp and tearDown instead of setUpClass and tearDownClass. That typically adds a lot of overhead to the test run, because creating a new Spark context is currently expensive.
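
For reference, a per-test variant might look roughly like this. It is a minimal sketch: the class name PerTestSparkTestCase is made up for illustration, and it registers the shutdown with addCleanup instead of writing a tearDown method:

```python
import unittest


class PerTestSparkTestCase(unittest.TestCase):
    """Hypothetical base class: a fresh SparkContext for every test method."""

    def setUp(self):
        # Imported here only to keep the sketch importable without Spark;
        # a module-level import works just as well.
        import pyspark

        conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
        self.sc = pyspark.SparkContext(conf=conf)
        # Registered right away, so the context is stopped even if the rest
        # of setUp or the test itself fails.
        self.addCleanup(self.sc.stop)
        self.spark = pyspark.SQLContext(self.sc)
```

Test classes would subclass it exactly like PySparkTestCase and use self.sc and self.spark, at the cost of one context start-up per test.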