Data.table | более быстрое рекурсивное обновление по группе в группе

Мне нужно выполнить следующую рекурсивную строку за строкой, чтобы получить z:

myfun = function (xb, a, b) {

z = NULL

for (t in 1:length(xb)) {

    if (t >= 2) { a[t] = b[t-1] + xb[t] }
    z[t] = rnorm(1, mean = a[t])
    b[t] = a[t] + z[t]

}

return(z)

}

set.seed(1)

n_smpl = 1e6 
ni = 5

id = rep(1:n_smpl, each = ni)

smpl = data.table(id)
smpl[, time := 1:.N, by = id]

a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

smpl[, z := myfun(xb, a, b), by = id]

Я хотел бы получить такой результат:

      id time a b  xb            z
  1:   1    1 1 1   1    0.3735462
  2:   1    2 1 1   2    2.7470924
  3:   1    3 1 1   3    8.4941848
  4:   1    4 1 1   4   20.9883695
  5:   1    5 1 1   5   46.9767390
 ---                              
496: 100    1 1 1 100    0.3735462
497: 100    2 1 1 200  200.7470924
498: 100    3 1 1 300  701.4941848
499: 100    4 1 1 400 1802.9883695
500: 100    5 1 1 500 4105.9767390

Это работает, но требует времени:

system.time(smpl[, z := myfun(xb, a, b), by = id])
   user  system elapsed 
 33.646   0.994  34.473

Мне нужно сделать это быстрее, учитывая размер моих фактических данных (более 2 миллионов наблюдений). Я думаю, do.call(myfun, .SD), .SDcols = c('xb', 'a', 'b') с by = .(id, time) будет намного быстрее, избегая цикла for внутри myfun. Тем не менее, я не был уверен, как я могу обновить b и его отставание (возможно, используя shift), когда я запускаю эту строку за строкой в ​​data.table. Любые предложения?

Ответы

Ответ 1

Отличный вопрос!

Начиная с нового сеанса R, показывая демо-данные с 5 миллионами строк, вот ваша функция от вопроса и времени на моем ноутбуке. С некоторыми комментариями.

require(data.table)   # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

myfun = function (xb, a, b) {

  z = NULL
  # initializes a new length-0 variable

  for (t in 1:length(xb)) {

      if (t >= 2) { a[t] = b[t-1] + xb[t] }
      # if() on every iteration. t==1 could be done before loop

      z[t] = rnorm(1, mean = a[t])
      # z vector is grown by 1 item, each time

      b[t] = a[t] + z[t]
      # assigns to all of b vector when only really b[t-1] is
      # needed on the next iteration 
  }
  return(z)
}

set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
   user  system elapsed 
 19.216   0.004  19.212

smpl
              id time a b      xb            z
      1:       1    1 1 1       1 3.735462e-01
      2:       1    2 1 1       2 3.557190e+00
      3:       1    3 1 1       3 9.095107e+00
      4:       1    4 1 1       4 2.462112e+01
      5:       1    5 1 1       5 5.297647e+01
     ---                                      
4999996: 1000000    1 1 1 1000000 1.618913e+00
4999997: 1000000    2 1 1 2000000 2.000000e+06
4999998: 1000000    3 1 1 3000000 7.000003e+06
4999999: 1000000    4 1 1 4000000 1.800001e+07
5000000: 1000000    5 1 1 5000000 4.100001e+07

Итак, 19.2s - время бить. Во всех этих таймингах я запускаю команду 3 раза локально, чтобы убедиться, что это стабильное время. Временная дисперсия незначительна в этой задаче, поэтому я просто сообщаю одно время, чтобы сохранить ответ быстрее.

Решение комментариев, приведенных выше в myfun():

myfun2 = function (xb, a, b) {

  z = numeric(length(xb))
  # allocate size up front rather than growing

  z[1] = rnorm(1, mean=a[1])
  prevb = a[1]+z[1]
  t = 2L
  while(t<=length(xb)) {
    at = prevb + xb[t]
    z[t] = rnorm(1, mean=at)
    prevb = at + z[t]
    t = t+1L
  }
  return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
   user  system elapsed 
 13.212   0.036  13.245 
smpl[,identical(z,z2)]
[1] TRUE

Это было неплохо (19.2s до 13.2s), но это все еще цикл for на уровне R. На первый взгляд он не может быть векторизован, потому что вызов rnorm() зависит от предыдущего значения. Фактически, это, вероятно, можно векторизовать, используя свойство m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd) и вызывая векторизованное rnorm(n=5e6) раз, а не 5e6 раз. Но, вероятно, для участия в группах будет задействован cumsum(). Так что пусть не идет туда, потому что это, вероятно, сделает код более трудным для чтения и будет специфичным для этой точной проблемы.

Итак, попробуйте Rcpp, который очень похож на стиль, который вы написали, и более широко применим:

require(Rcpp)   # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
  NumericVector z = NumericVector(xb.length());
  z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
  double prevb = a[0]+z[0];
  int t = 1;
  while (t<xb.length()) {
    double at = prevb + xb[t];
    z[t] = R::rnorm(at, 1);
    prevb = at + z[t];
    t++;
  }
  return z;
}')

set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
   user  system elapsed 
  1.800   0.020   1.819 
smpl[,identical(z,z3)]
[1] TRUE

Гораздо лучше: 19.2s до 1.8s. Но каждый вызов функции вызывает первую строку (NumericVector()), которая выделяет новый вектор, если число строк в группе. Затем он заполняется и возвращается, который копируется в последний столбец в правильном месте для этой группы (на :=), только для того, чтобы быть выпущенным. Это распределение и управление всеми этими 1 миллионом небольших временных векторов (по одному для каждой группы) немного запутано.

Почему бы нам не сделать весь столбец за один раз? Вы уже написали это в стиле цикла, и в этом нет ничего плохого. Позвольте настроить функцию C, чтобы принять столбец id и добавить if, когда он достигнет новой группы.

cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {

  // ** id must be pre-grouped, such as via setkey(DT,id) **

  NumericVector z = NumericVector(id.length());
  int previd = id[0]-1;  // initialize to anything different than id[0]
  for (int i=0; i<id.length(); i++) {
    double prevb;
    if (id[i]!=previd) {
      // first row of new group
      z[i] = R::rnorm(a[i], 1);
      prevb = a[i]+z[i];
      previd = id[i];
    } else {
      // 2nd row of group onwards
      double at = prevb + xb[i];
      z[i] = R::rnorm(at, 1);
      prevb = at + z[i];
    }
  }
  return z;
}')

system.time(setkey(smpl,id))  # ensure grouped by id
   user  system elapsed
  0.028   0.004   0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
   user  system elapsed 
  0.232   0.004   0.237 
smpl[,identical(z,z4)]
[1] TRUE

Это лучше: 19.2s до 0.27s.