Как свернуть собственный оценщик в PySpark mllib

Я пытаюсь создать простой пользовательский Estimator в PySpark MLlib. Я имею здесь, что можно написать собственный Transformer, но я не уверен, как это сделать в Estimator. Я также не понимаю, что делает @keyword_only, и зачем мне так много сеттеров и геттеров. У Scikit-learn есть соответствующий документ для пользовательских моделей (см. Здесь, но PySpark этого не делает.

Псевдокод примерной модели:

class NormalDeviation():
    def __init__(self, threshold = 3):
    def fit(x, y=None):
       self.model = {'mean': x.mean(), 'std': x.std()]
    def predict(x):
       return ((x-self.model['mean']) > self.threshold * self.model['std'])
    def decision_function(x): # does ml-lib support this?

Ответы

Ответ 1

Вообще говоря, нет документации, так как для Spark 1.6/2.0 большинство связанных API не предназначены для публичного доступа. Он должен измениться в Spark 2.1.0 (см. SPARK-7146).

API

является относительно сложным, поскольку он должен следовать конкретным соглашениям, чтобы сделать данный Transformer или Estimator совместимым с Pipeline API. Некоторые из этих методов могут потребоваться для таких функций, как чтение и запись или поиск сетки. Другие, например keyword_only, являются просто помощниками и не требуются строго.

Предполагая, что вы определили следующие миксы для среднего параметра:

from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):

    mean = Param(Params._dummy(), "mean", "mean", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        return self._set(mean=value)

    def getMean(self):
        return self.getOrDefault(self.mean)

параметр стандартного отклонения:

class HasStandardDeviation(Params):

    stddev = Param(Params._dummy(), "stddev", "stddev", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        return self._set(stddev=value)

    def getStddev(self):
        return self.getOrDefault(self.stddev)

и порог:

class HasCenteredThreshold(Params):

    centered_threshold = Param(Params._dummy(),
            "centered_threshold", "centered_threshold",
            typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        return self._set(centered_threshold=value)

    def getCenteredThreshold(self):
        return self.getOrDefault(self.centered_threshold)

вы можете создать базовый Estimator следующим образом:

class NormalDeviation(Estimator, HasInputCol, 
        HasPredictionCol, HasCenteredThreshold):

    def _fit(self, dataset):
        c = self.getInputCol()
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return (NormalDeviationModel()
            .setInputCol(c)
            .setMean(mu)
            .setStddev(sigma)
            .setCenteredThreshold(self.getCenteredThreshold())
            .setPredictionCol(self.getPredictionCol()))

class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
        HasMean, HasStandardDeviation, HasCenteredThreshold):

    def _transform(self, dataset):
        x = self.getInputCol()
        y = self.getPredictionCol()
        threshold = self.getCenteredThreshold()
        mu = self.getMean()
        sigma = self.getStddev()

        return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma)

Наконец, его можно использовать следующим образом:

df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"])

normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0)
model  = Pipeline(stages=[normal_deviation]).fit(df)

model.transform(df).show()
## +---+----+----------+
## | id|   x|prediction|
## +---+----+----------+
## |  1| 2.0|     false|
## |  2| 3.0|     false|
## |  3| 0.0|     false|
## |  4|99.0|      true|
## +---+----+----------+