Функция GBM R: получать переменную значимость отдельно для каждого класса
Я использую функцию gbm в R (пакет gbm) для установки моделей ускорения ускорения градиента для классификации многоклассов. Я просто пытаюсь получить значение каждого предиктора отдельно для каждого класса, как на этом рисунке из Hastie book (стр. 382).
![enter image description here]()
Однако функция summary.gbm возвращает только общую важность предикторов (их значение усредняется по всем классы).
Кто-нибудь знает, как получить значения относительной значимости?
Ответы
Ответ 1
Я думаю, что короткий ответ таков: на странице 379 Хасти упоминает, что он использует MART, который, по-видимому, доступен только для Splus.
Я согласен, что пакет gbm, похоже, не позволяет увидеть отдельное относительное влияние. Если вам что-то нужно для проблемы с mutliclass, возможно, вы получите что-то очень похожее, построив gbm one-vs-all для каждого из ваших классов, а затем получив значения важности от каждой из этих моделей.
Итак, ваши классы - это a, b, c и d. Вы моделируете против остальных и получаете значение от этой модели. Затем вы моделируете b против остальных и получаете значение от этой модели. Etc.
Ответ 2
Надеюсь, эта функция поможет вам. В качестве примера я использовал данные из пакета ElemStatLearn. Функция определяет, что представляют собой классы для столбца, разбивает данные на эти классы, запускает функцию gbm() для каждого класса и выставляет графики для этих моделей.
# install.packages("ElemStatLearn"); install.packages("gbm")
library(ElemStatLearn)
library(gbm)
set.seed(137531)
# formula: the formula to pass to gbm()
# data: the data set to use
# column: the class column to use
classPlots <- function (formula, data, column) {
class_column <- as.character(data[,column])
class_values <- names(table(class_column))
class_indexes <- sapply(class_values, function(x) which(class_column == x))
split_data <- lapply(class_indexes, function(x) marketing[x,])
object <- lapply(split_data, function(x) gbm(formula, data = x))
rel.inf <- lapply(object, function(x) summary.gbm(x, plotit=FALSE))
nobjs <- length(class_values)
for( i in 1:nobjs ) {
tmp <- rel.inf[[i]]
tmp.names <- row.names(tmp)
tmp <- tmp$rel.inf
names(tmp) <- tmp.names
barplot(tmp, horiz=TRUE, col='red',
xlab="Relative importance", main=paste0("Class = ", class_values[i]))
}
rel.inf
}
par(mfrow=c(1,2))
classPlots(Income ~ Marital + Age, data = marketing, column = 2)
`
![output]()