Matlab/octave - Обобщенное матричное умножение
Я хотел бы сделать функцию для обобщения умножения матрицы. В принципе, он должен иметь возможность выполнять стандартное умножение матрицы, но он должен позволять изменять два бинарных оператора product/sum любой другой функцией.
Цель должна быть максимально эффективной, как с точки зрения процессора, так и с точки зрения памяти. Конечно, он всегда будет менее эффективным, чем A * B, но гибкость операторов здесь.
Вот несколько команд, которые я мог придумать после прочтения различных интересных темы:
A = randi(10, 2, 3);
B = randi(10, 3, 4);
% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))
% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)
% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = C2 = squeeze(C(1,:,:)); % sum and turn into mxd
Проблема с методами 1-3 состоит в том, что они будут генерировать n матриц, прежде чем сворачивать их с помощью sum(). 4 лучше, потому что он выполняет сумму() внутри bsxfun, но bsxfun все еще генерирует n матриц (за исключением того, что они в основном пусты, содержащие только вектор значений non-zeros, являющихся суммами, остальное заполняется 0, чтобы соответствовать требования к размерам).
Что бы я хотел, это что-то вроде 4-го метода, но без бесполезной памяти 0.
Любая идея?
Ответы
Ответ 1
Вот немного более отполированная версия решения, которое вы опубликовали, с небольшими улучшениями.
Мы проверяем, есть ли у нас больше строк, чем столбцы или наоборот, и затем умножаем соответственно, выбирая либо умножать строки с матрицами или матрицами со столбцами (тем самым делая наименьшее количество итераций цикла).
![A*B]()
Примечание. Это может быть не всегда лучшая стратегия (переход по строкам вместо столбцов), даже если число строк меньше, чем столбцов; тот факт, что массивы MATLAB хранятся в порядковый номер столбца в памяти, делает его более эффективным для среза по столбцам, поскольку элементы хранятся последовательно, В то время как доступ к строкам включает в себя перемещение элементов strides (что не кэширует - думаю пространственная локальность).
Кроме этого, код должен обрабатывать двойной/одиночный, реальный/сложный, полный/разреженный (и ошибки, где это не возможная комбинация). Он также учитывает пустые матрицы и нулевые размеры.
function C = my_mtimes(A, B, outFcn, inFcn)
% default arguments
if nargin < 4, inFcn = @times; end
if nargin < 3, outFcn = @sum; end
% check valid input
assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
'Expecting function handles.')
% preallocate output matrix
M = size(A,1);
N = size(B,2);
if issparse(A)
args = {'like',A};
elseif issparse(B)
args = {'like',B};
else
args = {superiorfloat(A,B)};
end
C = zeros(M,N, args{:});
% compute matrix multiplication
% http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
if M < N
% concatenation of products of row vectors with matrices
% A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
for m=1:M
%C(m,:) = A(m,:) * B;
%C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
end
else
% concatenation of products of matrices with column vectors
% A*B = [A*b_1 , A*b_2 , ... , A*b_n]
for n=1:N
%C(:,n) = A * B(:,n);
%C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
end
end
end
Сравнение
Функция, без сомнения, медленная повсюду, но для больших размеров она на порядок хуже, чем встроенное матричное умножение:
(tic/toc times in seconds)
(tested in R2014a on Windows 8)
size mtimes my_mtimes
____ __________ _________
400 0.0026398 0.20282
600 0.012039 0.68471
800 0.014571 1.6922
1000 0.026645 3.5107
2000 0.20204 28.76
4000 1.5578 221.51
![mtimes_vs_mymtimes]()
Вот тестовый код:
sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
n = sz(i); disp(n)
A = rand(n,n);
B = rand(n,n);
tic
C = A*B;
t(i,1) = toc;
tic
D = my_mtimes(A,B);
t(i,2) = toc;
assert(norm(C-D) < 1e-6)
clear A B C D
end
semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight
Дополнительные
Для полноты ниже приведены еще два наивных способа реализации обобщенного умножения матрицы (если вы хотите сравнить производительность, замените последнюю часть функции my_mtimes
любым из них). Я даже не собираюсь публиковать свои прошедшие времена:)
C = zeros(M,N, args{:});
for m=1:M
for n=1:N
%C(m,n) = A(m,:) * B(:,n);
%C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
end
end
И еще один способ (с тройным циклом):
C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
for n=1:N
for p=1:P
%C(m,n) = C(m,n) + A(m,p)*B(p,n);
%C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
end
end
end
Что делать дальше?
Если вы хотите выжать больше производительности, вам нужно перейти в MEX файл C/С++, чтобы сократить накладные расходы на интерпретируемый код MATLAB. Вы можете использовать оптимизированные процедуры BLAS/LAPACK, вызвав их из MEX файлов (см. вторую часть этого сообщения для примера). MATLAB поставляется с библиотекой Intel MKL, которую, честно говоря, вы не можете победить, когда речь заходит о вычислениях линейной алгебры на процессорах Intel.
Другие уже упомянули пару заявок на Файловый обмен, которые реализуют универсальные матричные процедуры как MEX файлы (см. @natan ответ). Это особенно эффективно, если вы связываете их с оптимизированной библиотекой BLAS.
Ответ 2
Почему бы просто не использовать bsxfun
возможность принимать произвольную функцию?
C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1);
Здесь
-
f
- внешняя функция (соответствующая сумме в случае матричного умножения). Он должен принять трехмерный массив произвольного размера m
x n
x p
и работать по его столбцам, чтобы вернуть массив 1
x m
x p
.
-
g
- это внутренняя функция (соответствующая произведению в случае матричного умножения). В соответствии с bsxfun
он должен принимать в качестве входных данных либо два вектора столбца того же размера, либо один вектор-столбец и один скаляр, и возвращать в качестве вывода вектор-столбец того же размера, что и вход (-ы).
Это работает в Matlab. Я не тестировал в Octave.
Пример 1: Матричное умножение:
>> f = @sum; %// outer function: sum
>> g = @times; %// inner function: product
>> A = [1 2 3; 4 5 6];
>> B = [10 11; -12 -13; 14 15];
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
28 30
64 69
Check:
>> A*B
ans =
28 30
64 69
Пример 2. Рассмотрим две приведенные выше матрицы с
>> f = @(x,y) sum(abs(x)); %// outer function: sum of absolute values
>> g = @(x,y) max(x./y, y./x); %// inner function: "symmetric" ratio
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
14.8333 16.1538
5.2500 5.6346
Проверить: вручную вычислить C(1,2)
:
>> sum(abs( max( (A(1,:))./(B(:,2)).', (B(:,2)).'./(A(1,:)) ) ))
ans =
16.1538
Ответ 3
Без погружения в детали есть такие инструменты, как mtimesx и MMX, которые являются быстрыми операциями общей матрицы и скалярных операций общего назначения. Вы можете изучить их код и адаптировать их к вашим потребностям.
Скорее всего, это будет быстрее, чем matlab bsxfun.
Ответ 4
После изучения нескольких функций обработки, таких как bsxfun, кажется, что невозможно использовать прямое матричное умножение с помощью этих (то, что я подразумеваю под прямым, состоит в том, что временные продукты не хранятся в памяти, а суммируются как ASAP, а затем другие потому что они имеют выход фиксированного размера (либо тот же, что и вход, либо с расширением singles bsxfun, декартовым произведением размеров двух входов). Однако можно немного обмануть Octave (что не работает с MatLab, который проверяет размеры вывода):
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', sparse(1, size(A,1)))
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', zeros(1, size(A,1), 2))(:,:,2)
Однако не используйте их, потому что полученные значения не являются надежными (Octave может отменить или даже удалить их и вернуть 0!).
Итак, теперь я просто реализую полу-векторную версию, здесь моя функция:
function C = genmtimes(A, B, outop, inop)
% C = genmtimes(A, B, inop, outop)
% Generalized matrix multiplication between A and B. By default, standard sum-of-products matrix multiplication is operated, but you can change the two operators (inop being the element-wise product and outop the sum).
% Speed note: about 100-200x slower than A*A' and about 3x slower when A is sparse, so use this function only if you want to use a different set of inop/outop than the standard matrix multiplication.
if ~exist('inop', 'var')
inop = @times;
end
if ~exist('outop', 'var')
outop = @sum;
end
[n, m] = size(A);
[m2, o] = size(B);
if m2 ~= m
error('nonconformant arguments (op1 is %ix%i, op2 is %ix%i)\n', n, m, m2, o);
end
C = [];
if issparse(A) || issparse(B)
C = sparse(o,n);
else
C = zeros(o,n);
end
A = A';
for i=1:n
C(:,i) = outop(bsxfun(inop, A(:,i), B))';
end
C = C';
end
Протестировано как с разреженными, так и с нормальными матрицами: разрыв производительности намного меньше с разреженными матрицами (в 3 раза медленнее), чем с нормальными матрицами (~ 100x медленнее).
Я думаю, что это медленнее, чем реализация bsxfun, но, по крайней мере, он не переполняет память:
A = randi(10, 1000);
C = genmtimes(A, A');
Если кто-нибудь лучше предложить, я все равно ищу лучшую альтернативу.