【问题标题】:MATLAB - matrix multiply submatrices within a single matrixMATLAB - 矩阵乘以单个矩阵内的子矩阵
【发布时间】:2013-06-18 02:22:13
【问题描述】:

我正在尝试以“矢量化”方式将大 (2x2m) 矩阵的 (2x2) 子矩阵相乘,以消除 for 循环并提高速度。目前,我重塑为 (2x2xm) 然后使用 for 循环来执行此操作:

for n = 1:1e5
    m = 1e4;
    A = rand([2,2*m]);     % A is a function of n
    A = reshape(A,2,2,[]);
    B = eye(2);
    for i = 1:m
        B = A(:,:,i)*B;    % multiply the long chain of 2x2's
    end
end

函数目标类似于@prod,但使用矩阵乘法而不是逐元素标量乘法。 @multiprod 似乎很接近,但将两个不同的 nD 矩阵作为参数。我想象一个解决方案使用一个非常大的二维数组的多个子矩阵,或者一个 2x2m{xn} 数组来消除一个或两个 for 循环。

提前谢谢你,乔

【问题讨论】:

  • 听起来像是 bsxfun 的工作。
  • bsxfun 只允许 'times' (element-wise),而不是 'mtimes' (matrix) 作为函数参数
  • 如果A = [a1 a2 a3 ... am],那么使用B = A(:,:,i)*B将产生am*am-1*...*a2*a1,而使用B = B*A(:,:,i)将产生a1*a2*a3*...*am。矩阵乘积是不可交换的,因此这些结果通常是不同的。你要哪一个?
  • @RodyOldenhuis - 你是对的。我编码的是我正在寻找的东西,尽管重新排列 A 以满足需要应该很容易。

标签: matlab submatrix


【解决方案1】:

我认为您必须以不同的方式重塑矩阵才能进行矢量化乘法,如下面的代码所示。这段代码也使用了循环,但我认为应该更快

MM      = magic(2);
M0      = MM;
M1      = rot90(MM,1);
M2      = rot90(MM,2);
M3      = rot90(MM,3);


MBig1           = cat(2,M0,M1,M2,M3);
fprintf('Original matrix\n')
disp(MBig1)
MBig2           = zeros(size(MBig1,2));
MBig2(1:2,:)    = MBig1;
for k=0:3
    c1 =  k   *2+1;
    c2 = (k+1)*2+0;
    MBig2(:,c1:c2) = circshift(MBig2(:,c1:c2),[2*k 0]);
end
fprintf('Reshaped original matrix\n')
disp(MBig2)

fprintf('Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in direct way\n')
disp([ M0*M0 M0*M1 M0*M2 M0*M3 ])
fprintf('Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in vectorized way\n')
disp( kron(eye(4),M0)*MBig2 )


fprintf('Checking [ M0*M1*M2*M3 ] in direct way\n')
disp([ M0*M1*M2*M3 ])
fprintf('Checking [ M0*M1*M2*M3 ] in vectorized way\n')
R2 = MBig2;
for k=1:3
    R2 = R2 * circshift(MBig2,-[2 2]*k);
end
disp(R2)

输出是

Original matrix
     1     3     3     2     2     4     4     1
     4     2     1     4     3     1     2     3

Reshaped original matrix
     1     3     0     0     0     0     0     0
     4     2     0     0     0     0     0     0
     0     0     3     2     0     0     0     0
     0     0     1     4     0     0     0     0
     0     0     0     0     2     4     0     0
     0     0     0     0     3     1     0     0
     0     0     0     0     0     0     4     1
     0     0     0     0     0     0     2     3

Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in direct way
    13     9     6    14    11     7    10    10
    12    16    14    16    14    18    20    10

Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in vectorized way
    13     9     0     0     0     0     0     0
    12    16     0     0     0     0     0     0
     0     0     6    14     0     0     0     0
     0     0    14    16     0     0     0     0
     0     0     0     0    11     7     0     0
     0     0     0     0    14    18     0     0
     0     0     0     0     0     0    10    10
     0     0     0     0     0     0    20    10

Checking [ M0*M1*M2*M3 ] in direct way
   292   168
   448   292

Checking [ M0*M1*M2*M3 ] in vectorized way
   292   168     0     0     0     0     0     0
   448   292     0     0     0     0     0     0
     0     0   292   336     0     0     0     0
     0     0   224   292     0     0     0     0
     0     0     0     0   292   448     0     0
     0     0     0     0   168   292     0     0
     0     0     0     0     0     0   292   224
     0     0     0     0     0     0   336   292

【讨论】:

  • 1) 我喜欢将其改造成对角矩阵的概念。要创建 MBig2,'MBig2 = blkdiag(M0,M1,M2,M3)' 会做同样的事情吗? 2) 第一部分很好地乘以 M0*MBig。我相信我在网上找到的功能 multiprod 或 MTIMESX 会做到这一点;我将不得不比较速度。 3) 是'for k=1:3 R2 = R2 * circshift(MBig2,-[2 2]*k)的最终解; end'快于'for k = 1:3 B = A(:,:,k)*B;结束'?
【解决方案2】:

下面的函数可以解决我的部分问题。它被命名为“mprod”与 prod,类似于 times 与 mtimes。通过一些重塑,它递归地使用multiprod。通常,递归函数调用比循环慢。 Multiprod 声称要快 100 倍以上,所以它应该会补偿。

function sqMat = mprod(M)
    % Multiply *many* square matrices together, stored
    % as 3D array M. Speed gain through recursive use 
    % of function 'multiprod' (Leva, 2010).

    % check if M consists of multiple matrices
    if size(M,3) > 1
        % check for odd number of matrices
        if mod(size(M,3),2)
            siz = size(M,1);
            M = cat(3,M,eye(siz));
        end
        % create two smaller 3D arrays
        X = M(:,:,1:2:end); % odd pages
        Y = M(:,:,2:2:end); % even pages
        % recursive call
        sqMat = mprod(multiprod(X,Y));
    else
        % create final 2D matrix and break recursion
        sqMat = M(:,:,1);
    end
end

我没有测试过这个函数的速度或准确性。我相信这比循环快得多。它不会“矢量化”操作,因为它不能用于更高维度;此函数的任何重复使用都必须在循环内完成。

编辑 下面是新的代码,看起来运行得足够快。对函数的递归调用很慢并且会占用堆栈内存。仍然包含一个循环,但将循环数减少了 log(n)/log(2)。此外,还增加了对更多维度的支持。

function sqMats = mprod(M)
    % Multiply *many* square matrices together, stored along 3rd axis.
    % Extra dimensions are conserved; use 'permute' to change axes of "M".
    % Speed gained by recursive use of 'multiprod' (Leva, 2010).

    % save extra dimensions, then reshape
    dims = size(M);
    M = reshape(M,dims(1),dims(2),dims(3),[]);
    extraDim = size(M,4);

    % Check if M consists of multiple matrices...
    % split into two sets and multiply using multiprod, recursively
    siz = size(M,3);
    while siz > 1
        % check for odd number of matrices
        if mod(siz,2)
            addOn = repmat(eye(size(M,1)),[1,1,1,extraDim]);
            M = cat(3,M,addOn);
        end
        % create two smaller 3D arrays
        X = M(:,:,1:2:end,:); % odd pages
        Y = M(:,:,2:2:end,:); % even pages
        % recursive call and actual matrix multiplication
        M = multiprod(X,Y);
        siz = size(M,3);
    end

    % reshape to original dimensions, minus the third axis.
    dims(3) = [];
    sqMats = reshape(M,dims);
end

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2013-12-20
    • 1970-01-01
    • 2021-12-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多