поиск по сетке scikit по нескольким классификаторам

Я хотел знать, есть ли более эффективный способ создания сетки и тестирования нескольких моделей в одном конвейере. Разумеется, параметры моделей были бы разными, поэтому мне сложно это понять. Вот что я сделал:

from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.grid_search import GridSearchCV


def grid_search():
    pipeline1 = Pipeline((
    ('clf', RandomForestClassifier()),
    ('vec2', TfidfTransformer())
    ))

    pipeline2 = Pipeline((
    ('clf', KNeighborsClassifier()),
    ))

    pipeline3 = Pipeline((
    ('clf', SVC()),
    ))

    pipeline4 = Pipeline((
    ('clf', MultinomialNB()),
    ))

    parameters1 = {
    'clf__n_estimators': [10, 20, 30],
    'clf__criterion': ['gini', 'entropy'],
    'clf__max_features': [5, 10, 15],
    'clf__max_depth': ['auto', 'log2', 'sqrt', None]
    }

    parameters2 = {
    'clf__n_neighbors': [3, 7, 10],
    'clf__weights': ['uniform', 'distance']
    }

    parameters3 = {
    'clf__C': [0.01, 0.1, 1.0],
    'clf__kernel': ['rbf', 'poly'],
    'clf__gamma': [0.01, 0.1, 1.0],

    }
    parameters4 = {
    'clf__alpha': [0.01, 0.1, 1.0]
    }

    pars = [parameters1, parameters2, parameters3, parameters4]
    pips = [pipeline1, pipeline2, pipeline3, pipeline4]

    print "starting Gridsearch"
    for i in range(len(pars)):
        gs = GridSearchCV(pips[i], pars[i], verbose=2, refit=False, n_jobs=-1)
        gs = gs.fit(X_train, y_train)
        print "finished Gridsearch"
        print gs.best_score_

Однако этот подход все еще дает лучшую модель в каждом классификаторе, а не сравнение между классификаторами.

Ответы

Ответ 1

Хотя тема немного устарела, я выкладываю ответ на случай, если он кому-нибудь поможет в будущем.

Вместо использования Grid Search для выбора гиперпараметра вы можете использовать библиотеку 'hyperopt'.

Пожалуйста, ознакомьтесь с разделом 2.2 этой страницы. В приведенном выше случае вы можете использовать выражение "hp.choice" для выбора среди различных конвейеров, а затем определить выражения параметров для каждого из них в отдельности.

В вашей целевой функции вам необходимо выполнить проверку в зависимости от выбранного конвейера и вернуть оценку CV для выбранного конвейера и параметров (возможно, через cross_cal_score).

Объект испытаний в конце выполнения покажет лучший конвейер и параметры в целом.

Ответ 3

Хотя решение от dubek более прямолинейно, оно не помогает взаимодействовать между параметрами элементов конвейера, которые предшествуют classfier. Поэтому я написал вспомогательный класс для работы с ним, и его можно включить в настройку конвейера по умолчанию для scikit. Минимальный пример:

from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler, MaxAbsScaler
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
from pipelinehelper import PipelineHelper

iris = datasets.load_iris()
X_iris = iris.data
y_iris = iris.target
pipe = Pipeline([
    ('scaler', PipelineHelper([
        ('std', StandardScaler()),
        ('max', MaxAbsScaler()),
    ])),
    ('classifier', PipelineHelper([
        ('svm', LinearSVC()),
        ('rf', RandomForestClassifier()),
    ])),
])

params = {
    'scaler__selected_model': pipe.named_steps['scaler'].generate({
        'std__with_mean': [True, False],
        'std__with_std': [True, False],
        'max__copy': [True],  # just for displaying
    }),
    'classifier__selected_model': pipe.named_steps['classifier'].generate({
        'svm__C': [0.1, 1.0],
        'rf__n_estimators': [100, 20],
    })
}
grid = GridSearchCV(pipe, params, scoring='accuracy', verbose=1)
grid.fit(X_iris, y_iris)
print(grid.best_params_)
print(grid.best_score_)

Его также можно использовать для других элементов конвейера, а не только для классификатора. Код на GitHub, если кто-то хочет проверить это.