Быстрая линейная регрессия по группам

У меня есть 500 тыс. пользователей, и мне нужно вычислить линейную регрессию (с перехватом) для каждого из них.

Каждый пользователь имеет около 30 записей.

Я пробовал с dplyr и lm, и это слишком медленно. Пользователь в течение 2 секунд.

  df%>%                       
      group_by(user_id, add =  FALSE) %>%
      do(lm = lm(Y ~ x, data = .)) %>%
      mutate(lm_b0 = summary(lm)$coeff[1],
             lm_b1 = summary(lm)$coeff[2]) %>%
      select(user_id, lm_b0, lm_b1) %>%
      ungroup()
    )

Я попытался использовать lm.fit, который, как известно, работает быстрее, но он не кажется совместимым с dplyr.

Есть ли быстрый способ сделать линейную регрессию по группе?

Ответы

Ответ 1

Вы можете просто использовать основные формулы для расчета наклона и регрессии. lm делает много ненужных вещей, если все, о чем вы заботитесь, это эти два числа. Здесь я использую data.table для агрегации, но вы можете сделать это и в базе R (или dplyr):

system.time(
  res <- DT[, 
    {
      ux <- mean(x)
      uy <- mean(y)
      slope <- sum((x - ux) * (y - uy)) / sum((x - ux) ^ 2)
      list(slope=slope, intercept=uy - slope * ux)
    }, by=user.id
  ]
)

Производит для пользователей 500K ~ 30 обс каждый (в секундах):

 user  system elapsed 
 7.35    0.00    7.36 

Или о 15 микросекунд на пользователя. И подтвердить это работает как ожидалось:

> summary(DT[user.id==89663, lm(y ~ x)])$coefficients
             Estimate Std. Error   t value  Pr(>|t|)
(Intercept) 0.1965844  0.2927617 0.6714826 0.5065868
x           0.2021210  0.5429594 0.3722580 0.7120808
> res[user.id == 89663]
   user.id    slope intercept
1:   89663 0.202121 0.1965844

Данные:

set.seed(1)
users <- 5e5
records <- 30
x <- runif(users * records)
DT <- data.table(
  x=x, y=x + runif(users * records) * 4 - 2, 
  user.id=sample(users, users * records, replace=T)
)

Ответ 2

Если все, что вы хотите, это коэффициенты, я бы просто использовал user_id как фактор в регрессии. Использование симулированного кода данных @miles2know (хотя переименование, поскольку объект, отличный от exp(), разделяет это имя, выглядит странным для меня)

dat <- data.frame(id = rep(c("a","b","c"), each = 20),
                  x = rnorm(60,5,1.5),
                  y = rnorm(60,2,.2))

mod = lm(y ~ x:id + id + 0, data = dat)

Мы не подходим к глобальному перехвату (+ 0), так что перехват для каждого id является коэффициентом id, а не x сам по себе, так что взаимодействия x:id являются наклонами для каждого id

coef(mod)
#      ida      idb      idc    x:ida    x:idb    x:idc 
# 1.779686 1.893582 1.946069 0.039625 0.033318 0.000353 

Таким образом, для уровня a id коэффициент ida, 1,78, является перехватом, а коэффициент x:ida, 0.0396, является наклоном.

Я оставлю сбор этих коэффициентов в соответствующие столбцы кадра данных для вас...

Это решение должно быть очень быстрым, потому что вам не нужно иметь дело с подмножествами кадров данных. Вероятно, его можно было бы ускорить с помощью fastLm или такого.

Примечание по масштабируемости:

Я просто попробовал это на @nrussell, смоделировал полноразмерные данные и столкнулся с проблемами распределения памяти. В зависимости от того, сколько памяти у вас есть, оно может не работать за один раз, но вы, вероятно, можете сделать это в партиях идентификаторов пользователей. Некоторая комбинация его ответа и моего ответа может быть самым быстрым в целом - или nrussell может быть быстрее - расширение идентификатора пользователя в тысячи фиктивных переменных может быть неэффективным с точки зрения вычислительной мощности, поскольку я ожидал больше, чем пару минут для запуска всего 5000 идентификаторов пользователей.

Ответ 3

Update: Как отметил Дирк, мой оригинальный подход может быть значительно улучшен, если напрямую указать x и Y, а не использовать интерфейс на основе формул fastLm, который несет (довольно значительную) служебную нагрузку. Для сравнения, используя оригинальный набор данных полного размера,

R> system.time({
  dt[,c("lm_b0", "lm_b1") := as.list(
    unname(fastLm(x, Y)$coefficients))
    ,by = "user_id"]
})
#  user  system elapsed 
#55.364   0.014  55.401 
##
R> system.time({
  dt[,c("lm_b0","lm_b1") := as.list(
    unname(fastLm(Y ~ x, data=.SD)$coefficients))
    ,by = "user_id"]
})
#   user  system elapsed 
#356.604   0.047 356.820

это простое изменение дает примерно 6.5x speedup.


[Исходный подход]

Вероятно, есть кое-что для улучшения, но следующее заняло около 25 минут на Linux VM (процессор с тактовой частотой 2,6 ГГц), работающем под 64-разрядным R:

library(data.table)
library(RcppArmadillo)
##
dt[
  ,c("lm_b0","lm_b1") := as.list(
    unname(fastLm(Y ~ x, data=.SD)$coefficients)),
  by=user_id]
##
R> dt[c(1:2, 31:32, 61:62),]
   user_id   x         Y     lm_b0    lm_b1
1:       1 1.0 1674.8316 -202.0066 744.6252
2:       1 1.5  369.8608 -202.0066 744.6252
3:       2 1.0  463.7460 -144.2961 374.1995
4:       2 1.5  412.7422 -144.2961 374.1995
5:       3 1.0  513.0996  217.6442 261.0022
6:       3 1.5 1140.2766  217.6442 261.0022

Данные:

dt <- data.table(
  user_id = rep(1:500000,each=30))
##
dt[, x := seq(1, by=.5, length.out=30), by = user_id]
dt[, Y := 1000*runif(1)*x, by = user_id]
dt[, Y := Y + rnorm(
  30, 
  mean = sample(c(-.05,0,0.5)*mean(Y),1), 
  sd = mean(Y)*.25), 
  by = user_id]

Ответ 4

Вы можете попробовать попробовать с помощью таблицы данных. Я только что создал некоторые данные о игрушке, но я бы предположил, что data.table даст некоторое улучшение. Это довольно быстро. Но это довольно большой набор данных, поэтому, возможно, сравните этот подход с меньшим образцом, чтобы узнать, намного ли лучше скорость. удачи.


    library(data.table)

    exp <- data.table(id = rep(c("a","b","c"), each = 20), x = rnorm(60,5,1.5), y = rnorm(60,2,.2))
    # edit: it might also help to set a key on id with such a large data-set
    # with the toy example it would make no diff of course
    exp <- setkey(exp,id)
    # the nuts and bolts of the data.table part of the answer
    result <- exp[, as.list(coef(lm(y ~ x))), by=id]
    result
       id (Intercept)            x
    1:  a    2.013548 -0.008175644
    2:  b    2.084167 -0.010023549
    3:  c    1.907410  0.015823088