Ответ 1
Отличный вопрос!
Начиная с нового сеанса R, показывая демо-данные с 5 миллионами строк, вот ваша функция от вопроса и времени на моем ноутбуке. С некоторыми комментариями.
require(data.table) # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
myfun = function (xb, a, b) {
z = NULL
# initializes a new length-0 variable
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
# if() on every iteration. t==1 could be done before loop
z[t] = rnorm(1, mean = a[t])
# z vector is grown by 1 item, each time
b[t] = a[t] + z[t]
# assigns to all of b vector when only really b[t-1] is
# needed on the next iteration
}
return(z)
}
set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
user system elapsed
19.216 0.004 19.212
smpl
id time a b xb z
1: 1 1 1 1 1 3.735462e-01
2: 1 2 1 1 2 3.557190e+00
3: 1 3 1 1 3 9.095107e+00
4: 1 4 1 1 4 2.462112e+01
5: 1 5 1 1 5 5.297647e+01
---
4999996: 1000000 1 1 1 1000000 1.618913e+00
4999997: 1000000 2 1 1 2000000 2.000000e+06
4999998: 1000000 3 1 1 3000000 7.000003e+06
4999999: 1000000 4 1 1 4000000 1.800001e+07
5000000: 1000000 5 1 1 5000000 4.100001e+07
Итак, 19.2s - время бить. Во всех этих таймингах я запускаю команду 3 раза локально, чтобы убедиться, что это стабильное время. Временная дисперсия незначительна в этой задаче, поэтому я просто сообщаю одно время, чтобы сохранить ответ быстрее.
Решение комментариев, приведенных выше в myfun()
:
myfun2 = function (xb, a, b) {
z = numeric(length(xb))
# allocate size up front rather than growing
z[1] = rnorm(1, mean=a[1])
prevb = a[1]+z[1]
t = 2L
while(t<=length(xb)) {
at = prevb + xb[t]
z[t] = rnorm(1, mean=at)
prevb = at + z[t]
t = t+1L
}
return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
user system elapsed
13.212 0.036 13.245
smpl[,identical(z,z2)]
[1] TRUE
Это было неплохо (19.2s до 13.2s), но это все еще цикл for
на уровне R. На первый взгляд он не может быть векторизован, потому что вызов rnorm()
зависит от предыдущего значения. Фактически, это, вероятно, можно векторизовать, используя свойство m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd)
и вызывая векторизованное rnorm(n=5e6)
раз, а не 5e6 раз. Но, вероятно, для участия в группах будет задействован cumsum()
. Так что пусть не идет туда, потому что это, вероятно, сделает код более трудным для чтения и будет специфичным для этой точной проблемы.
Итак, попробуйте Rcpp, который очень похож на стиль, который вы написали, и более широко применим:
require(Rcpp) # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
NumericVector z = NumericVector(xb.length());
z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
double prevb = a[0]+z[0];
int t = 1;
while (t<xb.length()) {
double at = prevb + xb[t];
z[t] = R::rnorm(at, 1);
prevb = at + z[t];
t++;
}
return z;
}')
set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
user system elapsed
1.800 0.020 1.819
smpl[,identical(z,z3)]
[1] TRUE
Гораздо лучше: 19.2s до 1.8s. Но каждый вызов функции вызывает первую строку (NumericVector()
), которая выделяет новый вектор, если число строк в группе. Затем он заполняется и возвращается, который копируется в последний столбец в правильном месте для этой группы (на :=
), только для того, чтобы быть выпущенным. Это распределение и управление всеми этими 1 миллионом небольших временных векторов (по одному для каждой группы) немного запутано.
Почему бы нам не сделать весь столбец за один раз? Вы уже написали это в стиле цикла, и в этом нет ничего плохого. Позвольте настроить функцию C, чтобы принять столбец id
и добавить if
, когда он достигнет новой группы.
cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {
// ** id must be pre-grouped, such as via setkey(DT,id) **
NumericVector z = NumericVector(id.length());
int previd = id[0]-1; // initialize to anything different than id[0]
for (int i=0; i<id.length(); i++) {
double prevb;
if (id[i]!=previd) {
// first row of new group
z[i] = R::rnorm(a[i], 1);
prevb = a[i]+z[i];
previd = id[i];
} else {
// 2nd row of group onwards
double at = prevb + xb[i];
z[i] = R::rnorm(at, 1);
prevb = at + z[i];
}
}
return z;
}')
system.time(setkey(smpl,id)) # ensure grouped by id
user system elapsed
0.028 0.004 0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
user system elapsed
0.232 0.004 0.237
smpl[,identical(z,z4)]
[1] TRUE
Это лучше: 19.2s до 0.27s.