【问题标题】:How to speed up my code [includes example] in Matlab?如何在 Matlab 中加快我的代码 [包括示例]?
【发布时间】:2016-03-11 17:38:31
【问题描述】:

我想加快我的代码速度。目前,我使用 if 语句来做到这一点。但是,如果我们使用convolution way,它可以编写更快的代码。但是,它仅适用于简单情况(作为成对邻域)。让我们定义我的问题。

我有一个矩阵I=[1 1 1;2 2 2;2 2 1],它有两个标签{1,2}。我将填充添加为其右侧。对于I 中的每个像素,我们可以定义成对或三元组邻域。我们将根据规则“如果这些邻域值与像素具有相同的类别,则设置成本值等于-beta,否则设置成本等于beta”。

例如,让我们考虑上图中的黄色像素。它的标签是 2。我们需要计算可能的邻域情况下的总成本值,如最右侧所示。有趣像素的值将从标签 {1,2} 设置。在上图中。我只展示了将黄色像素设置为 1 的第一种情况。我们可以有相同的数字,但在下一种情况下设置黄色像素为 2。我的任务是根据上述规则计算成本函数。

这是我的代码。但是,它使用 if 语句。太慢了。你能帮我加快速度吗?我尝试使用卷积方式,但我不知道如何为三邻域定义掩码。谢谢大家

function U=compute_gibbs(Imlabel,beta,num_class)
num_class=2;
Imlabel=[1 1 1;2 2 2;2 2 1]
beta=1;
U=zeros([size(Imlabel) num_class]);
Imlabel = padarray(Imlabel,[1 1],'replicate','both');
[row,col] = size(Imlabel);
for ii = 2:row-1        
    for jj = 2:col-1
        for l = 1:num_class
            U(ii-1,jj-1,l)=GibbsEnergy(Imlabel,ii,jj,l,beta);
        end
    end
end
function energy = GibbsEnergy(img,i,j,label,beta)
    % img is the labeled image
    energy = 0;
    if (label == img(i,j)) energy = energy-beta;
        else energy = energy+beta;end        
    % North, south, east and west
    if (label == img(i-1,j)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i,j-1)) energy = energy-beta;
        else energy = energy+beta;end
    % diagonal elements
    if (label == img(i-1,j-1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i-1,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j-1)) energy = energy-beta;
        else energy = energy+beta;end
     %% Triangle elements
    % Case a 
    if(label==img(i-1,j)&label==img(i-1,j-1)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end
    if(label==img(i,j+1)&&label==img(i+1 ,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
    % Case b 
    if(label==img(i-1,j-1)&label==img(i,j-1)) energy = energy-beta;
        else energy = energy+beta;end     
     if(label==img(i-1,j)&label==img(i ,j+1)) energy = energy-beta;
        else energy = energy+beta;end  
     if(label==img(i+1,j)&label==img(i+1,j+1)) energy = energy-beta;
         else energy = energy+beta;end  
    % Case c   
    if(label==img(i,j-1)&label==img(i+1,j-1)) energy = energy-beta;
         else energy = energy+beta;end  
    if(label==img(i+1,j)&label==img(i,j+1)) energy = energy-beta;
         else energy = energy+beta;end  
    if(label==img(i-1 ,j)&label==img(i-1,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
    % Case d 
    if(label==img(i,j-1)&label==img(i-1,j)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i-1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 

    %% Rectangular
    if(label==img(i-1,j-1)&label==img(i,j-1)&label==img(i-1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i,j-1)&label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i+1,j)&label==img(i +1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i-1,j)&label==img(i-1,j+1)&label==img(i ,j+1)) energy = energy-beta;
        else energy = energy+beta;end 

这是一种更快的方法。但它只适用于简单的案例(成对的邻域第一行),而我的案例包括单、三...邻域

C = double(bsxfun(@eq, Imlabel, permute(1:num_class, [1 3 2])));
C(C == 0) = 0;
C(C == 1) = beta;
%% Replace if statement
mask = zeros(3,3); mask(2,2) = 1;
Cpad = convn(C, mask);
Cpad(Cpad == 0) = 0;

mask2 = ones(3,3); mask2(2,2) = 0;
energy = convn(Cpad, mask2, 'valid');

【问题讨论】:

  • 为什么不创建蒙版然后将它们相乘?
  • @AnderBiguri:你认为它会更快吗?
  • 是的。 MATtrix LABoratory 旨在最快地使用矩阵
  • 你能实现它吗?我想和原始代码比较一下计算时间?谢谢
  • 哈哈哈哈哈哈。 没有。你看到你的代码有多长了吗?

标签: algorithm performance matlab image-processing


【解决方案1】:

这是我的尝试。我真的无法判断是否有任何一个对你来说会更快,因为我使用的是 Octave 而不是 MATLAB,而且时间可能会非常不同。例如,for 循环在 Octave 中仍然需要很长时间。您必须对它们进行测试,看看它们的比较情况。

矩阵乘法

作为@AnderBiguri notes in the comments,一种方法是使用矩阵乘法。如果您选择 3x3 社区,请说

nbr = [0 0 0;
       1 0 0;
       1 1 0];

如果您想知道左上角的元素是否为1,您可以通过掩码执行逐元素乘法

mask = [1 0 0;
        0 0 0;
        0 0 0];

result = sum(mask .* nbr);

(我在这里假设邻域是一个二进制矩阵来走捷径。当我得到实际代码时,我将简单地使用nbr == current_class 来实现这一点。)

如果结果与掩码具有相同数量的1 元素,那么您就有了匹配项。在这种情况下,这两者的元素乘法都是零,所以不匹配。

我们可以将nbrmask 制作成向量并使用向量乘法,而不是逐个元素乘法然后对结果的元素求和:

m = mask(:).';
n = nbr(:);
result = m * n;

这将为您提供与先前结果相同的值。如果您有一个掩码矩阵,则可以将其乘以邻域向量并一次获得所有结果。所以第一步是生成25个掩码向量:

masks = [
   0   0   0   0   1   0   0   0   0;
   0   0   0   0   0   1   0   0   0;
   0   0   0   1   0   0   0   0   0;
   0   0   0   0   0   0   0   1   0;
   0   1   0   0   0   0   0   0   0;
   1   0   0   0   0   0   0   0   0;
   0   0   0   0   0   0   0   0   1;
   0   0   0   0   0   0   1   0   0;
   0   0   1   0   0   0   0   0   0;
   1   1   0   0   0   0   0   0   0;
   1   0   0   1   0   0   0   0   0;
   0   0   0   1   0   0   1   0   0;
   0   0   0   0   0   0   1   1   0;
   0   0   0   0   0   0   0   1   1;
   0   0   0   0   0   1   0   0   1;
   0   0   1   0   0   1   0   0   0;
   0   1   1   0   0   0   0   0   0;
   0   0   0   1   0   0   0   1   0;
   0   0   0   0   0   1   0   1   0;
   0   1   0   1   0   0   0   0   0;
   0   1   0   0   0   1   0   0   0;
   1   1   0   1   0   0   0   0   0;
   0   0   0   1   0   0   1   1   0;
   0   0   0   0   0   1   0   1   1;
   0   1   1   0   0   1   0   0   0];

现在,当您将 masks 乘以邻域时,您会立即获得所有结果。然后将结果与masks 的行的总和进行比较,看看哪些匹配。

result = masks * n;
matches = sum(masks, 2) == result;
match_count = sum(matches);

对于每场比赛,我们从能量中减去beta。对于每个不匹配,我们添加beta,所以

possible_matches = 25; %// the number of neighborhood types
energy = -beta * match_count + beta * (possible_matches - match_count);

现在我们要做的就是弄清楚如何从我们的图像中取出所有 3x3 邻域。幸运的是,MATLAB 有 im2col 函数可以做到这一点。更好的是,它只需要图像的有效邻域,所以如果它已经被填充,你就可以开始了。

function G = gibbs(img, beta, classcount)

   masks = [
      0   0   0   0   1   0   0   0   0;
      0   0   0   0   0   1   0   0   0;
      0   0   0   1   0   0   0   0   0;
      0   0   0   0   0   0   0   1   0;
      0   1   0   0   0   0   0   0   0;
      1   0   0   0   0   0   0   0   0;
      0   0   0   0   0   0   0   0   1;
      0   0   0   0   0   0   1   0   0;
      0   0   1   0   0   0   0   0   0;
      1   1   0   0   0   0   0   0   0;
      1   0   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   0   0;
      0   0   0   0   0   0   1   1   0;
      0   0   0   0   0   0   0   1   1;
      0   0   0   0   0   1   0   0   1;
      0   0   1   0   0   1   0   0   0;
      0   1   1   0   0   0   0   0   0;
      0   0   0   1   0   0   0   1   0;
      0   0   0   0   0   1   0   1   0;
      0   1   0   1   0   0   0   0   0;
      0   1   0   0   0   1   0   0   0;
      1   1   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   1   0;
      0   0   0   0   0   1   0   1   1;
      0   1   1   0   0   1   0   0   0];

   [m,n] = size(img);
   possible_matches = size(masks, 1);
   Imlabel = padarray(img, [1 1], 'replicate', 'both');

   col_label = im2col(Imlabel, [3 3], 'sliding');
   target = repmat(sum(masks, 2), [1, m*n]);

   for ii = 1:classcount
      found = masks*(col_label == ii);
      match_count = sum(found == target, 1);
      energy = -beta * match_count + beta*(possible_matches - match_count);
      G(:,:,ii) = reshape(energy, m, n);
   end

end

查找表

如果您查看矩阵乘法解决方案,它会将每个像素的邻域乘以 25 个掩码。对于 1000 x 1000 的图像,这是 1000 x 1000 x 25 x 9 = 225M 乘法。但是只有512 (2^9) 个可能的邻居配置。因此,如果我们弄清楚这 512 个配置是什么,将它们乘以掩码,然后将匹配项相加,我们就得到了一个 512 元素的查找表,我们对图像中的每个邻域要做的就是计算它的指数。以下是使用上面的masks 创建查找表的方法:

possible_neighborhoods = de2bi(0:511, 9).';
found = masks * possible_neighborhoods;
target = repmat(sum(masks, 2), [1, size(found, 2)]);
LUT = sum(found == target, 1);

这几乎是我们之前在每个循环中所做的,但我们正在为所有可能的邻域做这件事,这相当于数字 0:511 的所有位模式。

现在,我们希望查找表中的十进制索引,而不是我们乘以掩码的每个像素的二进制向量。为此,我们可以将conv2 与一个有效地进行二进制到十进制转换的内核一起使用:

k = [1   8   64;
     2  16  128;
     4  32  256];

or

k = [2^0  2^3  2^6
     2^1  2^4  2^7
     2^2  2^5  2^8];

这将为我们提供每个像素的0:511 值,因此我们将一个添加到1:512 并将其用作查找表的索引。完整代码如下:

function G = gibbs2(img, beta, classcount)

   masks = [
      0   0   0   0   1   0   0   0   0;
      0   0   0   0   0   1   0   0   0;
      0   0   0   1   0   0   0   0   0;
      0   0   0   0   0   0   0   1   0;
      0   1   0   0   0   0   0   0   0;
      1   0   0   0   0   0   0   0   0;
      0   0   0   0   0   0   0   0   1;
      0   0   0   0   0   0   1   0   0;
      0   0   1   0   0   0   0   0   0;
      1   1   0   0   0   0   0   0   0;
      1   0   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   0   0;
      0   0   0   0   0   0   1   1   0;
      0   0   0   0   0   0   0   1   1;
      0   0   0   0   0   1   0   0   1;
      0   0   1   0   0   1   0   0   0;
      0   1   1   0   0   0   0   0   0;
      0   0   0   1   0   0   0   1   0;
      0   0   0   0   0   1   0   1   0;
      0   1   0   1   0   0   0   0   0;
      0   1   0   0   0   1   0   0   0;
      1   1   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   1   0;
      0   0   0   0   0   1   0   1   1;
      0   1   1   0   0   1   0   0   0];

   [m,n] = size(img);
   possible_matches = size(masks, 1);
   possible_neighborhoods = de2bi(0:511, 9).';   %'
   found = masks * possible_neighborhoods;
   target = repmat(sum(masks, 2), [1, size(found, 2)]);
   LUT = sum(found == target, 1);
   
   k = [1   8   64;
        2  16  128;
        4  32  256];
        
   Imlabel = padarray(img, [1 1], 'replicate', 'both');

   for ii = 1:classcount
      filterImage = conv2(double(Imlabel == ii), k, 'valid');
      matchImg = LUT(filterImage + 1);
      G(:,:,ii) = -beta * matchImg + beta * (possible_matches - matchImg);
   end
   
end

由于我们对 1000x1000 图像执行的乘法操作要少得多,因此这种方法比我的机器上使用 Octave 的矩阵乘法方法快大约 7 倍。

【讨论】:

  • 真的很棒。我正在测量性能。我稍后会告诉你。我在 matlab 的第二个代码中发现了一个小错误。行 Imlabel == ii (in conv2(Imlabel == ii,...)) 需要转换为 double as double(Imlabel == ii)。
  • 是的,显然 MATLAB 希望 conv2 的参数是单参数或双参数。我会继续改变它。
  • 感谢烧杯。我花时间了解和衡量你的方式。作为我在 MATLAB 中的检查,与我的方式相比,对于 500x500 大小的 4 类矩阵,您的第一种方式加速了 2 倍,第二种方式加速了 13 倍。这确实是一个很好的方法。我理解了第一种方式。但是,我仍然不了解第二种方式。为什么有 2^9 个可能的邻居,而不是 2^10? 4个邻域有多少可能的邻居配置(是2^2吗?))?
  • 它是 2^9,因为附近有 9 个单元格,每个单元格都是二进制的(它要么匹配类,要么不匹配)。对于 4 个连接的邻域,它将是 2^5,因为您有中心单元及其 4 个邻居。
  • 我们可以在chat.stackoverflow.com/rooms/106177/…一起聊天
猜你喜欢
  • 2014-10-07
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2015-08-10
  • 2013-10-03
  • 1970-01-01
  • 2019-10-01
相关资源
最近更新 更多