【问题标题】:How to speed up this mex code?如何加快这个 mex 代码?
【发布时间】:2012-08-29 15:05:20
【问题描述】:

我正在用 mex 重新编程一段 MATLAB 代码(使用 C)。到目前为止,我的 C 版本的 MATLAB 代码大约是 MATLAB 代码的两倍。现在我有三个问题,都与下面的代码有关:

  1. 我怎样才能加快这段代码的速度?
  2. 您是否发现此代码有任何问题?我问这个是因为我不太了解 mex,而且我也不是 C 大师 ;-) ...我知道代码中应该有一些检查(例如,使用 @ 时是否仍有堆空间987654321@,但为了简单起见,我暂时不使用它)
  3. 有没有可能,MATLAB 的优化如此之好,以至于我真的无法在 C 语言中获得比两倍快的代码...?

代码应该或多或少独立于平台(Win、Linux、Unix、Mac、不同的硬件),所以我不想使用汇编程序或特定的线性代数库。所以这就是我自己编程员工的原因......

#include <mex.h>
#include <math.h>
#include <matrix.h>

void mexFunction(
    int nlhs, mxArray *plhs[],
    int nrhs, const mxArray *prhs[])
{
    double epsilon = ((double)(mxGetScalar(prhs[0])));
    int strengthDim = ((int)(mxGetScalar(prhs[1])));
    int lenPartMat = ((int)(mxGetScalar(prhs[2])));
    int numParts = ((int)(mxGetScalar(prhs[3])));
    double *partMat = mxGetPr(prhs[4]);
    const mxArray* verletListCells = prhs[5];
    mxArray *verletList;

    double *pseSum = (double *) malloc(numParts * sizeof(double));
    for(int i = 0; i < numParts; i++) pseSum[i] = 0.0;

    float *tempVar = NULL;

    for(int i = 0; i < numParts; i++)
    {
        verletList = mxGetCell(verletListCells,i);
        int numberVerlet = mxGetM(verletList);

        tempVar = (float *) realloc(tempVar, numberVerlet * sizeof(float) * 2);


        for(int a = 0; a < numberVerlet; a++)
        {
            tempVar[a*2] = partMat[((int) (*(mxGetPr(verletList) + a))) - 1] - partMat[i];
            tempVar[a*2 + 1] = partMat[((int) (*(mxGetPr(verletList) + a))) - 1 + lenPartMat] - partMat[i + lenPartMat];

            tempVar[a*2] = pow(tempVar[a*2],2);
            tempVar[a*2 + 1] = pow(tempVar[a*2 + 1],2);

            tempVar[a*2] = tempVar[a*2] + tempVar[a*2 + 1];
            tempVar[a*2] = sqrt(tempVar[a*2]);

            tempVar[a*2] = 4.0/(pow(epsilon,2) * M_PI) * exp(-(pow((tempVar[a*2]/epsilon),2)));
            pseSum[i] = pseSum[i] + ((partMat[((int) (*(mxGetPr(verletList) + a))) - 1 + 2*lenPartMat] - partMat[i + (2 * lenPartMat)]) * tempVar[a*2]);
        }

    }

    plhs[0] = mxCreateDoubleMatrix(numParts,1,mxREAL);
    for(int a = 0; a < numParts; a++)
    {
        *(mxGetPr(plhs[0]) + a) = pseSum[a];
    }

    free(tempVar);
    free(pseSum);
}

所以这是改进版,比 MATLAB 版快 12 倍左右。转换的事情仍然占用了很多时间,但我暂时放弃了,因为我必须为此在 MATLAB 中进行一些更改。所以首先关注剩下的 C 代码。您在以下代码中看到了更多潜力吗?

#include <mex.h>
#include <math.h>
#include <matrix.h>

void mexFunction(
    int nlhs, mxArray *plhs[],
    int nrhs, const mxArray *prhs[])
{
    double epsilon = ((double)(mxGetScalar(prhs[0])));
    int strengthDim = ((int)(mxGetScalar(prhs[1])));
    int lenPartMat = ((int)(mxGetScalar(prhs[2])));
    double *partMat = mxGetPr(prhs[3]);
    const mxArray* verletListCells = prhs[4];
    int numParts = mxGetM(verletListCells);
    mxArray *verletList;

    plhs[0] = mxCreateDoubleMatrix(numParts,1,mxREAL);
    double *pseSum = mxGetPr(plhs[0]);

    double epsilonSquared = epsilon*epsilon;

    double preConst = 4.0/((epsilonSquared) * M_PI);

    int numberVerlet = 0;

    double tempVar[2];

    for(int i = 0; i < numParts; i++)
    {
        verletList = mxGetCell(verletListCells,i);
        double *verletListPtr = mxGetPr(verletList);
        numberVerlet = mxGetM(verletList);

        for(int a = 0; a < numberVerlet; a++)
        {
            int adress = ((int) (*(verletListPtr + a))) - 1;

            tempVar[0] = partMat[adress] - partMat[i];
            tempVar[1] = partMat[adress + lenPartMat] - partMat[i + lenPartMat];

            tempVar[0] = tempVar[0]*tempVar[0] + tempVar[1]*tempVar[1];

            tempVar[0] = preConst * exp(-(tempVar[0]/epsilonSquared));
            pseSum[i] += ((partMat[adress + 2*lenPartMat] - partMat[i + (2*lenPartMat)]* tempVar[0]);
        }

    }

}

【问题讨论】:

  • 你能把原始的Matlab代码也贴出来吗?通常,最佳的速度优化是在算法设计级别执行的。

标签: c performance matlab optimization mex


【解决方案1】:
  • 您无需分配 pseSum 以供本地使用,然后再将数据复制到输出。您可以简单地分配一个 MATLAB 对象并获取指向内存的指针:

    plhs[0] = mxCreateDoubleMatrix(numParts,1,mxREAL);
    pseSum  = mxGetPr(plhs[0]);
    

因此您不必将 pseSum 初始化为 0,因为 MATLAB 已经在 mxCreateDoubleMatrix 中完成了。

  • 从内循环中移除所有的mxGetPr,并将它们分配给之前的变量。

  • 考虑在 MATLAB 中使用 int32 或 uint32 数组,而不是将双精度数转换为整数。将 double 转换为 int 是昂贵的。内部循环计算看起来像

    tempVar[a*2] = partMat[somevar[a] - 1] - partMat[i];
    

    您在代码中使用此类结构

    ((int) (*(mxGetPr(verletList) + a)))
    

    您这样做是因为 varletList 是一个“双”数组(在 MATLAB 中默认就是这种情况),其中包含整数值。相反,您应该使用整数数组。在 MATLAB 中调用 mex 文件类型之前:

    varletList = int32(varletList);
    

    那么您将不需要上面的类型转换为 int。你会简单地写

    ((int*)mxGetData(verletList))[a]
    

    或者更好的是,更早分配

    somevar = (int*)mxGetData(verletList);
    

    后来写

    somevar[a]
    
  • 在所有循环之前预计算 4.0/(pow(epsilon,2) * M_PI)!这是一个昂贵的常数。

  • pow((tempVar[a*2]/epsilon),2)) 就是 tempVar[a*2]^2/epsilon^2。您刚刚计算 sqrt(tempVar[a*2]) 。为什么现在就摆正?

  • 一般不使用 pow(x, 2)。就写 x*x

  • 我会在参数上添加一些健全性检查,特别是如果您需要整数。要么使用 MATLABs int32/uint32 类型,要么检查你得到的实际是一个整数。

编辑在新代码中

  • 在循环之前计算 -1/epsilonSquared 并计算 exp(minvepssq*tempVar[0])。请注意,结果可能略有不同。取决于你需要什么,但如果你不关心操作的确切顺序,那就去做吧。

  • 定义一个寄存器变量 preSum_r 并使用它对内部循环中的结果求和。在循环之后将其分配给 preSum[i]。如果您想要更多乐趣,可以使用 SSE 流式存储(_mm_stream_pd 编译器内在)将结果写入内存。

  • 将 double 移除为 int cast

  • 很可能不相关,但尝试将 tempVar[0/1] 更改为正常变量。无关紧要,因为编译器应该为您执行此操作。但同样,这里不需要数组。

  • 将外部循环与 OpenMP 并行化。微不足道(至少是最简单的版本,没有考虑 NUMA 架构的数据布局),因为迭代之间没有依赖关系。

【讨论】:

  • 谢谢你们的cmets!!!其中一些真的很简单,或者只是基本的数学。我自己应该已经看到了 :-( 现在我的代码比 Matlab 代码快了大约 10 倍,这非常令人满意。我实现了你的所有提示,除了来自 angainor 关于双打演员的提示。你能告诉我一个更多关于那个,不太明白......(特别是怎么做)。为什么 x*x 比 pow(x,2) 更好?
  • 你是对的,它可能无关紧要——如果编译器相当好,它会发现这种优化,而不是调用 pow 函数,这在一般情况下更昂贵,只会执行 X *X。但为什么要冒险呢?与循环内计算的昂贵常数相同。编译器可能会找到它,但为什么要把它放在首位呢?最后,您应该只进行基准测试和检查..
  • 另一个问题。对此进行了一些谷歌搜索,但我仍然不确定。 C 中有没有更快的 exp() 版本???当然,这个优化到了最大,但也许有些精度有所损失,但仍然足以满足我的目的......
  • 我建议发布一个关于 exp 的新问题。在此处链接此问题以显示您拥有的代码类型。
  • 您好,谢谢您的建议。现在使用 openMP,这可以加快整个过程!!!所以这现在真的足够快了!
【解决方案2】:

您能否提前估计tempVar 的最大大小并在循环之前为其分配内存而不是使用realloc?重新分配内存是一项耗时的操作,如果您的 numParts 很大,这可能会产生巨大的影响。看看this question

【讨论】:

    猜你喜欢
    • 2012-02-11
    • 2023-03-16
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-10-05
    • 1970-01-01
    相关资源
    最近更新 更多