Using tensorflow.keras model in pyspark UDF generates a pickle error

I would like to use a tensorflow.keras model in a PySpark pandas_udf. However, I get a pickling error when the model is serialized before being sent to the workers. I am not sure I am using the best approach for what I want to do, so here is a minimal but complete example.

Packages:

  • tensorflow-2.2.0 (the error is also triggered with all previous versions)
  • pyspark-2.4.5

The import statements are:

import pandas as pd
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

from pyspark.sql import SparkSession, functions as F, types as T

The PySpark UDF is a pandas_udf:

def compute_output_pandas_udf(model):
    '''Spark pandas udf for model prediction.'''

    @F.pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
    def compute_output(inputs1, inputs2, inputs3):
        pdf = pd.DataFrame({
            'input1': inputs1,
            'input2': inputs2,
            'input3': inputs3
        })
        pdf['predicted_output'] = model.predict(pdf.values)
        return pdf['predicted_output']

    return compute_output

The main code:

# Model parameters
weights = np.array([[0.5], [0.4], [0.3]])
bias = np.array([1.25])
activation = 'linear'
input_dim, output_dim = weights.shape

# Initialize model
model = Sequential()
layer = Dense(output_dim, input_dim=input_dim, activation=activation)
model.add(layer)
layer.set_weights([weights, bias])

# Initialize Spark session
spark = SparkSession.builder.appName('test').getOrCreate()

# Create pandas df with inputs and run model
pdf = pd.DataFrame({
    'input1': np.random.randn(200),
    'input2': np.random.randn(200),
    'input3': np.random.randn(200)
})
pdf['predicted_output'] = model.predict(pdf[['input1', 'input2', 'input3']].values)

# Create spark df with inputs and run model using udf
sdf = spark.createDataFrame(pdf)
sdf = sdf.withColumn('predicted_output', compute_output_pandas_udf(model)('input1', 'input2', 'input3'))
sdf.limit(5).show()

This error is triggered when compute_output_pandas_udf(model) is called:

PicklingError: Could not serialize object: TypeError: can't pickle _thread.RLock objects
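The message means that somewhere in the object graph PySpark serializes for the UDF (the closure, including the captured `model`), there is a `threading.RLock`, and lock objects cannot be pickled. A minimal stdlib-only reproduction, assuming a hypothetical `ModelWithLock` class as a stand-in for the Keras model and a hypothetical `try_pickle` helper:

```python
import pickle
import threading


class ModelWithLock:
    """Hypothetical stand-in for a Keras model that keeps a lock internally."""

    def __init__(self):
        self.weights = [0.5, 0.4, 0.3]
        self._lock = threading.RLock()  # lock objects cannot be pickled


def try_pickle(obj):
    """Return None if obj can be pickled, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)
    return None


error = try_pickle(ModelWithLock())
print(error)  # the exact wording depends on the Python version; it names _thread.RLock
```

Plain data such as the weight list pickles fine; it is only the lock buried inside the object that makes the whole object unpicklable.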

I found this page about pickling a keras model and tried it on tensorflow.keras, but I got the following error when the model's predict function is called in the UDF (so serialization worked but deserialization did not?):

AttributeError: 'Sequential' object has no attribute '_distribution_strategy'

Does anyone have an idea how to proceed? Thank you in advance!

PS: Note that I did not use a model directly from the keras library because I get another error periodically with it, which seems more difficult to solve. However, serializing that model does not raise an error as it does with the tensorflow.keras model.


It looks like if we use the solution of extending the __getstate__ and __setstate__ methods directly on the tensorflow.keras.models.Model class, as in http://zachmoshe.com/2017/04/03/pickling-keras-models.html, the workers are not able to deserialize the model, because their copy of the class does not have this extension.
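This can be reproduced in a single process. pickle stores an instance as a reference to its class (module and name) plus whatever `__getstate__` returns, and the receiving side looks the class up in its own import of that module. In the sketch below, `Model`, `patched_getstate` and `patched_setstate` are hypothetical stand-ins: the class is patched while pickling (the driver), and the patch is removed before unpickling (a worker that never ran the patch). Without a `__setstate__`, pickle simply copies the state dict into the new instance's `__dict__`, so the rebuilt object only has `model_str` and none of its real attributes, which is consistent with the `'Sequential' object has no attribute '_distribution_strategy'` error from the question.

```python
import json
import pickle


class Model:
    """Hypothetical stand-in for tensorflow.keras.models.Model."""

    def __init__(self, weights=(0.5, 0.4, 0.3)):
        self.weights = list(weights)
        self._distribution_strategy = None  # internal attribute that predict() would use


def patched_getstate(self):
    # Serialize the model to a string, like the blog-post patch does with HDF5.
    return {"model_str": json.dumps(self.weights)}


def patched_setstate(self, state):
    # Rebuild a fully initialized model from the serialized string.
    self.__init__(json.loads(state["model_str"]))


# "Driver": the class is patched, so pickling and unpickling both work here.
Model.__getstate__ = patched_getstate
Model.__setstate__ = patched_setstate
payload = pickle.dumps(Model())
on_driver = pickle.loads(payload)

# "Worker": its copy of the class was never patched.
del Model.__getstate__
del Model.__setstate__
on_worker = pickle.loads(payload)

print(hasattr(on_driver, "_distribution_strategy"))  # True
print(vars(on_worker))  # {'model_str': '[0.5, 0.4, 0.3]'}
print(hasattr(on_worker, "_distribution_strategy"))  # False
```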

The solution, then, is to use a wrapper class, as Erp12 suggested in this post:

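import tempfile

import tensorflow


class ModelWrapperPickable:
    '''Wraps a tf.keras model so that only its HDF5 bytes get pickled.'''

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Save the model to a temporary HDF5 file and pickle the raw bytes
        # instead of the model object itself.
        # Note: reopening fd.name while fd is still open works on Unix, not on Windows.
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            tensorflow.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        return {'model_str': model_str}

    def __setstate__(self, state):
        # On the worker: write the bytes back to a temporary file and reload the model.
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            self.model = tensorflow.keras.models.load_model(fd.name)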
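As far as I understand, the wrapper works because the Keras model itself is never pickled: the wrapper's own `__getstate__` returns only bytes, and since the wrapper class is defined in the driver's `__main__`, cloudpickle ships its definition (including `__setstate__`) along with the instance. The same pattern can be exercised without TensorFlow; `LockedModel` and `ModelWrapper` below are hypothetical stand-ins, and JSON replaces the HDF5 file purely for illustration. A similar `pickle.loads(pickle.dumps(model_wrapper))` check on the driver is a cheap way to catch problems before Spark runs the UDF.

```python
import json
import pickle
import threading


class LockedModel:
    """Hypothetical stand-in for a Keras model: usable, but not picklable."""

    def __init__(self, weights, bias):
        self.weights = list(weights)
        self.bias = bias
        self._lock = threading.RLock()

    def predict(self, rows):
        # Linear model: dot(weights, row) + bias for each row.
        return [sum(w * x for w, x in zip(self.weights, row)) + self.bias for row in rows]


class ModelWrapper:
    """Same idea as ModelWrapperPickable: pickle bytes, not the model object."""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Plays the role of save_model(...) followed by reading the HDF5 file.
        payload = {"weights": self.model.weights, "bias": self.model.bias}
        return {"model_str": json.dumps(payload).encode("utf-8")}

    def __setstate__(self, state):
        # Plays the role of writing the bytes to a file followed by load_model(...).
        payload = json.loads(state["model_str"].decode("utf-8"))
        self.model = LockedModel(payload["weights"], payload["bias"])


wrapper = ModelWrapper(LockedModel([0.5, 0.4, 0.3], 1.25))
restored = pickle.loads(pickle.dumps(wrapper))  # works, unlike pickling wrapper.model
print(restored.model.predict([[1.0, 2.0, 3.0]]))
```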

The UDF becomes:

def compute_output_pandas_udf(model_wrapper):
    '''Spark pandas udf for model prediction.'''

    @F.pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
    def compute_output(inputs1, inputs2, inputs3):
        pdf = pd.DataFrame({
            'input1': inputs1,
            'input2': inputs2,
            'input3': inputs3
        })
        pdf['predicted_output'] = model_wrapper.model.predict(pdf.values)
        return pdf['predicted_output']

    return compute_output

And the main code:

# Model parameters
weights = np.array([[0.5], [0.4], [0.3]])
bias = np.array([1.25])
activation = 'linear'
input_dim, output_dim = weights.shape

# Initialize keras model
model = Sequential()
layer = Dense(output_dim, input_dim=input_dim, activation=activation)
model.add(layer)
layer.set_weights([weights, bias])
# Initialize model wrapper
model_wrapper = ModelWrapperPickable(model)

# Initialize Spark session
spark = SparkSession.builder.appName('test').getOrCreate()

# Create pandas df with inputs and run model
pdf = pd.DataFrame({
    'input1': np.random.randn(200),
    'input2': np.random.randn(200),
    'input3': np.random.randn(200)
})
pdf['predicted_output'] = model_wrapper.model.predict(pdf[['input1', 'input2', 'input3']].values)

# Create spark df with inputs and run model using udf
sdf = spark.createDataFrame(pdf)
sdf = sdf.withColumn('predicted_output', compute_output_pandas_udf(model_wrapper)('input1', 'input2', 'input3'))
sdf.limit(5).show()
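Because the Dense layer uses a linear activation, its prediction is just `X @ weights + bias`, so the values printed by `sdf.limit(5).show()` can be checked against a plain NumPy computation. A minimal sketch; `expected_output` is a hypothetical helper name and the input rows are made up for illustration:

```python
import numpy as np

# Same parameters as in the main code above.
weights = np.array([[0.5], [0.4], [0.3]])
bias = np.array([1.25])


def expected_output(inputs):
    """Linear Dense layer: (n, 3) inputs -> (n,) predictions."""
    return (np.asarray(inputs) @ weights + bias).ravel()


print(expected_output([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))
```

On the Spark side, comparing `sdf.toPandas()['predicted_output']` with `expected_output` applied to the three input columns via `np.allclose(..., atol=1e-5)` should work; a tolerance is needed because Keras computes the predictions in float32.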