Почему это наивное матричное умножение быстрее, чем базовое R?
В R матричное умножение очень оптимизировано, т.е. на самом деле это просто вызов BLAS/LAPACK. Тем не менее, я удивлен, что этот наивный C++ код для умножения матричных векторов кажется на 30% быстрее.
library(Rcpp)
# Simple C++ code for matrix multiplication
mm_code =
"NumericVector my_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
double v_j;
for(int j = 0; j < nCol; j++){
v_j = v[j];
for(int i = 0; i < nRow; i++){
ans[i] += m(i,j) * v_j;
}
}
return(ans);
}
"
# Compiling
my_mm = cppFunction(code = mm_code)
# Simulating data to use
nRow = 10^4
nCol = 10^4
m = matrix(rnorm(nRow * nCol), nrow = nRow)
v = rnorm(nCol)
system.time(my_ans <- my_mm(m, v))
#> user system elapsed
#> 0.103 0.001 0.103
system.time(r_ans <- m %*% v)
#> user system elapsed
#> 0.154 0.001 0.154
# Double checking answer is correct
max(abs(my_ans - r_ans))
#> [1] 0
Занимает ли база R %*%
определенный тип проверки данных, которую я пропускаю?
EDIT: Поняв, что происходит (спасибо SO!), Стоит отметить, что это худший сценарий для R %*%
, т.е. Матрица по вектору. Например, @RalfStubner отметил, что использование RcppArmadillo даже быстрее, чем наивная реализация, но практически идентично для матричной матрицы умножить (когда обе матрицы большие и квадратные):
arma_code <-
"arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
return m * m2;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
nRow = 10^3
nCol = 10^3
mat1 = matrix(rnorm(nRow * nCol),
nrow = nRow)
mat2 = matrix(rnorm(nRow * nCol),
nrow = nRow)
system.time(arma_mm(mat1, mat2))
#> user system elapsed
#> 0.798 0.008 0.814
system.time(mat1 %*% mat2)
#> user system elapsed
#> 0.807 0.005 0.822
Таким образом, ток R (v3.5.0) %*%
близок к оптимальному для матричной матрицы, но может быть значительно ускорен для матричного вектора, если вы в порядке пропускаете проверку.
Ответы
Ответ 1
Быстрый взгляд на names.c
(здесь в частности) указывает на do_matprod
, функцию C, вызываемую %*%
и найденную в файле array.c
. (Интересно, оказывается, что и crossprod
и tcrossprod
отправляются на эту же функцию). Вот ссылка на код do_matprod
.
Прокручивая функцию, вы можете видеть, что она позаботится о нескольких вещах, которых ваша наивная реализация не включает, в том числе:
- Сохраняет имена строк и столбцов, где это имеет смысл.
- Позволяет отправлять альтернативные методы S4, когда два объекта, которыми управляет вызов
%*%
имеют классы, для которых были предоставлены такие методы. (Что происходит в этой части функции.) - Обрабатывает как реальные, так и сложные матрицы.
- Реализует ряд правил для обработки умножения матрицы и матрицы, вектора и матрицы, матрицы и вектора, а также вектора и вектора. (Напомним, что при кросс-умножении в R вектор на LHS рассматривается как вектор строки, тогда как на RHS он рассматривается как вектор-столбец, это код, который делает это.)
В конце функции она отправляется либо в matprod
либо в cmatprod
. Интересно (по крайней мере, для меня), в случае реальных матриц, если любая матрица может содержать значения NaN
или Inf
, тогда matprod
отправляет (здесь) в функцию simple_matprod
которая примерно такая же простая и простая, как ваша собственная. В противном случае он отправляет одну из двух подпрограмм BLAS Fortran, которые, предположительно, быстрее, если можно гарантировать равномерное "хорошее поведение" матричных элементов.
Ответ 2
Ответ Джоша объясняет, почему умножение матрицы R не так быстро, как этот наивный подход. Мне было любопытно узнать, сколько можно получить, используя RcppArmadillo. Код достаточно прост:
arma_code <-
"arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
return m * v;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
Ориентир:
> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 71.23347 75.22364 90.13766 96.88279 98.07348 98.50182 10
m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751 10
arma_mm(m, v) 41.13348 41.42314 41.89311 41.81979 42.39311 42.78396 10
Таким образом, RcppArmadillo дает нам лучший синтаксис и лучшую производительность.
Любопытство стало лучше меня. Вот решение для непосредственного использования BLAS:
blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
char trans = 'N';
double one = 1.0, zero = 0.0;
int ione = 1;
F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
&ione, &zero, ans.begin(), &ione);
return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")
Ориентир:
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 72.61298 75.40050 89.75529 96.04413 96.59283 98.29938 10
m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572 10
arma_mm(m, v) 41.06718 41.70331 42.62366 42.47320 43.22625 45.19704 10
blas_mm(m, v) 41.58618 42.14718 42.89853 42.68584 43.39182 44.46577 10
Armadillo и BLAS (OpenBLAS в моем случае) почти одинаковы. И код BLAS - это то, что делает R в конце. Таким образом, 2/3 того, что R делает проверку ошибок и т.д.