How to create a custom Estimator in PySpark
Generally speaking, there is no documentation because, as of Spark 1.6 / 2.0, most of the related API is not intended to be public. This should change in Spark 2.1.0 (see SPARK-7146).
The API is relatively complex because it has to follow specific conventions in order to make a given Transformer or Estimator compatible with the Pipeline API. Some of these methods may be required for features like reading and writing or grid search. Others, like keyword_only, are just simple helpers and are not strictly required.
Assuming you have defined the following mix-ins for the mean parameter:
from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp
class HasMean(Params):
    """Mix-in that adds a ``mean`` param, converted with ``TypeConverters.toFloat``."""

    # Class-level param descriptor, declared against a dummy parent.
    mean = Param(
        parent=Params._dummy(),
        name="mean",
        doc="mean",
        typeConverter=TypeConverters.toFloat,
    )

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        """Set ``mean`` to ``value`` and return whatever ``_set`` returns."""
        result = self._set(mean=value)
        return result

    def getMean(self):
        """Return the value of ``mean`` as resolved by ``getOrDefault``."""
        param = self.mean
        return self.getOrDefault(param)
the standard deviation parameter:
class HasStandardDeviation(Params):
    """Mix-in that adds a ``standardDeviation`` param, converted with ``TypeConverters.toFloat``."""

    # Class-level param descriptor, declared against a dummy parent.
    standardDeviation = Param(
        parent=Params._dummy(),
        name="standardDeviation",
        doc="standardDeviation",
        typeConverter=TypeConverters.toFloat,
    )

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        """Set ``standardDeviation`` to ``value`` and return whatever ``_set`` returns."""
        result = self._set(standardDeviation=value)
        return result

    def getStddev(self):
        """Return the value of ``standardDeviation`` as resolved by ``getOrDefault``."""
        param = self.standardDeviation
        return self.getOrDefault(param)
and the threshold:
class HasCenteredThreshold(Params):
    """Mix-in that adds a ``centeredThreshold`` param, converted with ``TypeConverters.toFloat``."""

    # Class-level param descriptor, declared against a dummy parent.
    centeredThreshold = Param(
        parent=Params._dummy(),
        name="centeredThreshold",
        doc="centeredThreshold",
        typeConverter=TypeConverters.toFloat,
    )

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        """Set ``centeredThreshold`` to ``value`` and return whatever ``_set`` returns."""
        result = self._set(centeredThreshold=value)
        return result

    def getCenteredThreshold(self):
        """Return the value of ``centeredThreshold`` as resolved by ``getOrDefault``."""
        param = self.centeredThreshold
        return self.getOrDefault(param)
you could create a basic Estimator as follows:
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark import keyword_only
class NormalDeviation(Estimator, HasInputCol,
                      HasPredictionCol, HasCenteredThreshold,
                      DefaultParamsReadable, DefaultParamsWritable):
    """Estimator that learns the mean and sample standard deviation of ``inputCol``.

    ``_fit`` aggregates the input column with ``avg`` and ``stddev_samp`` and
    returns a :class:`NormalDeviationModel` carrying those two statistics
    together with this estimator's ``centeredThreshold`` and ``predictionCol``.

    Parameters
    ----------
    inputCol : str, optional
        Name of the numeric column to summarise.
    predictionCol : str, optional
        Name of the output column written by the fitted model.
    centeredThreshold : float, optional
        Multiplier applied to the standard deviation by the fitted model.
        Defaults to ``1.0``.
    """

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        super(NormalDeviation, self).__init__()
        # ``keyword_only`` stores only the arguments the caller passed
        # explicitly in ``_input_kwargs``, so the signature default alone would
        # never reach the param map and ``getCenteredThreshold()`` would fail
        # when the argument is omitted. Register it as the param default.
        self._setDefault(centeredThreshold=1.0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        """
        Sets the explicitly supplied params; omitted ones are left untouched.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _fit(self, dataset):
        """
        Compute the mean and sample standard deviation of the input column
        and wrap them in a :class:`NormalDeviationModel`.
        """
        c = self.getInputCol()
        # Single aggregation; ``first()`` yields the one resulting row, which
        # unpacks into the two statistics in the order they were requested.
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return NormalDeviationModel(
            inputCol=c, mean=mu, standardDeviation=sigma,
            centeredThreshold=self.getCenteredThreshold(),
            predictionCol=self.getPredictionCol())
class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
                           HasMean, HasStandardDeviation, HasCenteredThreshold,
                           DefaultParamsReadable, DefaultParamsWritable):
    """Model that flags input values lying more than ``centeredThreshold``
    standard deviations above the stored mean.

    All params default to ``None`` in the signatures, so they have to be
    supplied explicitly before ``_transform`` can read them.
    """

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None,
                 mean=None, standardDeviation=None,
                 centeredThreshold=None):
        super(NormalDeviationModel, self).__init__()
        # ``keyword_only`` stashed the caller's keyword arguments on the
        # instance; forward them unchanged to ``setParams``.
        self.setParams(**self._input_kwargs)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None,
                  mean=None, standardDeviation=None,
                  centeredThreshold=None):
        """Set the keyword arguments that were passed to this call."""
        return self._set(**self._input_kwargs)

    def _transform(self, dataset):
        """Return ``dataset`` with a boolean prediction column appended.

        The new column holds ``(input - mean) > centeredThreshold * standardDeviation``.
        """
        source_col = self.getInputCol()
        target_col = self.getPredictionCol()
        multiplier = self.getCenteredThreshold()
        center = self.getMean()
        spread = self.getStddev()

        deviation = dataset[source_col] - center
        cutoff = multiplier * spread
        return dataset.withColumn(target_col, deviation > cutoff)
Credit goes to Benjamin-Manns for the use of DefaultParamsReadable and DefaultParamsWritable, which are available in PySpark >= 2.3.0.
Finally, it could be used as follows:
# Build a small four-row DataFrame with an "id" column and a numeric "x" column.
rows = [(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]
df = sc.parallelize(rows).toDF(["id", "x"])

# Configure the estimator step by step instead of through one fluent chain.
normal_deviation = NormalDeviation()
normal_deviation = normal_deviation.setInputCol("x")
normal_deviation = normal_deviation.setCenteredThreshold(1.0)

# Fit it as the only stage of a Pipeline, then score the same data.
model = Pipeline(stages=[normal_deviation]).fit(df)
model.transform(df).show()
## +---+----+----------+
## | id|   x|prediction|
## +---+----+----------+
## |  1| 2.0|     false|
## |  2| 3.0|     false|
## |  3| 0.0|     false|
## |  4|99.0|      true|
## +---+----+----------+