【问题标题】:Fast multipliction of multiple matrices by multiple vectors多个矩阵与多个向量的快速乘法
【发布时间】:2017-09-15 20:47:10
【问题描述】:

在 matlab 中,我想用 L 个矩阵将 M 个向量相乘,得到 M x L 个新向量。具体来说,假设我有一个大小为 N x M 的矩阵 A 和一个大小为 N x N x L 矩阵的矩阵 B,我想计算一个大小为 N x M x L 的矩阵 C,其结果与下面的慢代码:

for m=1:M
    for l=1:L
         C(:,m,l)=B(:,:,l)*A(:,m)
    end
end

但要有效地实现这一点(使用本机代码而不是 matlab 循环)。

【问题讨论】:

    标签: matlab performance matrix matrix-multiplication


    【解决方案1】:

    我们可以ab-use fast matrix-multiplication 这里,只需要重新排列尺寸。因此,将B 的第二个维度推回末尾并重新整形为2D,以便合并前两个维度。使用A 执行矩阵乘法,得到一个二维数组。我们称之为C。现在,C's 第一个暗淡是来自B 的合并暗淡。因此,通过重新整形将其拆分回原来的两个暗淡长度,从而产生一个 3D 数组。最后再用一个permute 将第二个暗淡推到后面。这是所需的3D 输出。

    因此,实现将是 -

    permute(reshape(reshape(permute(B,[1,3,2]),[],N)*A,N,L,[]),[1,3,2])
    

    基准测试

    基准代码:

    % Setup inputs
    M = 150;
    L = 150;
    N = 150;
    A = randn(N,M);
    B = randn(N,N,L);
    
    disp('----------------------- ORIGINAL LOOPY -------------------')
    tic
    C_loop = NaN(N,M,L);
    for m=1:M
        for l=1:L
             C_loop(:,m,l)=B(:,:,l)*A(:,m);
        end
    end
    toc
    
    disp('----------------------- BSXFUN + PERMUTE -----------------')
    % @Luis's soln
    tic
    C = permute(sum(bsxfun(@times, permute(B, [1 2 4 3]), ...
                            permute(A, [3 1 2])), 2), [1 3 4 2]);
    toc
    
    disp('----------------------- BSXFUN + MATRIX-MULT -------------')
    % Propose in this post
    tic
    out = permute(reshape(reshape(permute(B,[1,3,2]),[],N)*A,N,L,[]),[1,3,2]);
    toc
    

    时间:

    ----------------------- ORIGINAL LOOPY -------------------
    Elapsed time is 0.905811 seconds.
    ----------------------- BSXFUN + PERMUTE -----------------
    Elapsed time is 0.883616 seconds.
    ----------------------- BSXFUN + MATRIX-MULT -------------
    Elapsed time is 0.045331 seconds.
    

    【讨论】:

    • 像往常一样,矩阵乘法击败bsxfun!干得好!!
    【解决方案2】:

    您可以通过一些维度的排列和单例扩展来做到这一点:

    C = permute(sum(bsxfun(@times, permute(B, [1 2 4 3]), permute(A, [3 1 2])), 2), [1 3 4 2]);
    

    检查:

    % Example inputs:
    M = 5;
    L = 6;
    N = 7;
    A = randn(N,M);
    B = randn(N,N,L);
    
    % Output with bsxfun and permute:    
    C = permute(sum(bsxfun(@times, permute(B, [1 2 4 3]), permute(A, [3 1 2])), 2), [1 3 4 2]);
    
    % Output with loops:
    C_loop = NaN(N,M,L);
    for m=1:M
        for l=1:L
             C_loop(:,m,l)=B(:,:,l)*A(:,m);
        end
    end
    
    % Maximum relative error. Should be 0, or of the order of eps:
    max_error = max(reshape(abs(C./C_loop),[],1)-1) 
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-02-28
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-01-25
      • 2021-12-22
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多