【问题标题】:optimization, branching elimination优化,分支消除
【发布时间】:2018-04-21 05:55:32
【问题描述】:
float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
    float inputLevel = ... //in range -1.0f to 1.0f
    if(inputLevel < 0.0 && mixValue < 0.0)
    {
        mixValue = (mixValue + inputLevel) + (mixValue*inputLevel);
    }
    else
    {
        mixValue = (mixValue + inputLevel) - (mixValue*inputLevel);
    }
}

只是一个简单的问题,我们可以计算mixValue 没有分支吗?或任何其他优化建议,例如使用 SIMD?

编辑: 只是为了更多信息,我结束了 根据选择的答案使用此解决方案:

const float sign[] = {-1, 1};
float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
    float inputLevel = ... //in range -1.0f to 1.0f
    unsigned a = *(unsigned*)(&mixValue);
    unsigned b = *(unsigned*)(&inputLevel);

    float mulValue = mixValue * inputLevel * sign[(a & b) >> (8*sizeof(unsigned)-1)];
    float addValue = mixValue + inputLevel;
    mixValue = addValue + mulValue;
}

谢谢。

【问题讨论】:

  • 你确定这正是你想要做的吗?
  • 我敢肯定,它工作得很好,作为参考,您可以参考 [-1.0f,1.0f] 范围内浮点样本的音频波混合算法
  • 请注意,如果 mixValueinputLevel 为 0.0,则两个分支是相同的。此外,如果inputLevel 为0.0,您实际上不需要做任何事情。但我也怀疑公式是错误的。这样的公式通常是奇数或偶数; f(-x)==f(x)f(-x)==-f(x)。你的也不是。
  • 没错,如果上一个mixValue==0 然后mixValue=inputLevel,如果下一个inputLevel==0 然后mixValue 保持不变。您可以看到,就好像您将可听声音和静音声音混合在一起一样,不会有任何变化。但是如果你混合两个嘈杂的声音,并且这些声音的峰值是满足的,你就不能添加这两个值,因为它会超出最大值。
  • 这篇文章很旧,但我还是会发表评论。您可以使用以下内容计算不乘的符号:(float)(-1 + (int)(((a &amp; b) &gt;&gt; (8*sizeof(unsigned)-1)) &lt;&lt; 1))(数组索引具有隐式乘法)

标签: c++ optimization floating-point branch


【解决方案1】:

这个怎么样:

const float sign[] = {-1, 1};

float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
    float inputLevel = ... //in range -1.0f to 1.0f
    int bothNegative = (inputLevel < 0.0) & (mixValue < 0.0);
    mixValue = (mixValue + inputLevel) + (sign[bothNegative]*mixValue*inputLevel);
}

编辑: Mike 是正确的,&& 会引入一个分支,感谢 Pedro 证明了这一点。我将 && 更改为 &,现在 GCC(4.4.0 版)生成无分支代码。

【讨论】:

  • 问题是:如果bothNefative为false,则等于0,所以永远不可能是负数。
  • @Klaim: sign[0] 是-1,所以sign[bothNegative]bothNegative==0 是-1
  • 啊,是的,我没看到数组。这很聪明!我花了太多时间试图得到这样的东西,我从来没有想过数组中的预定义值 XD 如此简单......
  • &amp;&amp; 可能会引入一个分支;如果您真的想确定,请使用&amp;
  • 在 MSVC++10 上编译,发布,这确实是分支:img641.imageshack.us/img641/7063/floattestbranch.png(IDA 截图)。将 &amp;&amp; 替换为 &amp; 会产生类似的结果。
【解决方案2】:

受 Roku 的回答(在 MSVC++10 分支上)的启发,这似乎没有分支:

#include <iostream>

using namespace std;
const float sign[] = {-1, 1};
int main() {
    const int N = 10;
    float mixValue = -0.5F;
    for(int i = 0; i < N; i++) {
        volatile float inputLevel = -0.3F;
        int bothNegative = ((((unsigned char*)&inputLevel)[3] & 0x80) & (((unsigned char*)&mixValue)[3] & 0x80)) >> 7;
        mixValue = (mixValue + inputLevel) + (sign[bothNegative]*mixValue*inputLevel);
    }

    std::cout << mixValue << std::endl;
}

这是 IDA Pro 分析的反汇编(在 MSVC++10 上编译,发布模式):

Disassembly http://img248.imageshack.us/img248/6865/floattestbranchmine.png

【讨论】:

  • 只是为了确保编译器不会优化它。
  • 免责声明:与大多数位旋转代码一样,它依赖于内置类型(此处为浮点)在内存中的表示,并且不能假定为可移植的(32 / 64 位等...)
【解决方案3】:
float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
     float inputLevel = ... //in range -1.0f to 1.0f
     float mulValue = mixValue * inputLevel;
     float addValue = mixValue + inputLevel;
     __int32 a = *(__int32*)(&mixValue);
     __int32 b = *(__int32*)(&inputLevel);
     __int32 c = *(__int32*)(&mulValue);
     __int32 d = c & ((a ^ b) | 0x7FFFFFFF);
     mixValue = addValue + *(float*)(&d);
}

【讨论】:

    【解决方案4】:

    就在我的头顶(我相信它可以减少):

    mixValue = (mixValue + inputLevel) + (((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1) / fabs(((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1)))*-1*(mixValue*inputLevel);

    为了澄清一点,我将单独计算符号:

    float sign = (((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1) / fabs(((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1)))*-1;
    mixValue = (mixValue + inputLevel) + sign*(mixValue*inputLevel);
    

    这是浮点数学,因此您可能需要纠正一些舍入问题,但我认为这应该让您走上正确的道路。

    【讨论】:

    • 我敢打赌,除法比分支效率还要低。
    • @NullUserException: fabs() 可以在没有分支的情况下计算。
    【解决方案5】:

    如果您担心过度分支,请查看Duff's Device。这应该有助于在某种程度上解开循环。说实话,循环展开是由优化器完成的,所以尝试手动完成可能是浪费时间。检查汇编输出以找出答案。

    如果您对数组中的每个项目执行完全相同的操作,SIMD 肯定会有所帮助。请注意,并非所有硬件都支持 SIMD,但一些编译器(如 gcc)确实为 SIMD 提供了内在函数,这将使您免于陷入汇编程序。

    如果您使用 gcc 编译 ARM 代码,可以在 here 找到 SIMD 内部函数

    【讨论】:

    • 我已经看到了 asm 输出,它确实展开了循环,但仍然存在分支,它创建了两个代码路径,并且没有应用 SIMD,尽管它使用了类似 mulss 的指令或在 xmm reg 上添加了指令
    • 您的代码的问题是 mixValue 在迭代之间发生了变化,所以我猜这里不可能有 SIMD。
    【解决方案6】:

    您是否对带有和不带有分支的循环进行了基准测试?

    至少你可以删除分支的一部分,因为 mixValue 在循环之外。

    float multiplier(float a, float b){
      unsigned char c1Neg = reinterpret_cast<unsigned char *>(&a)[3] & 0x80;
      unsigned char c2Neg = reinterpret_cast<unsigned char *>(&b)[3] & 0x80;
      unsigned char multiplierIsNeg = c1Neg & c2Neg;
      float one = 1;
      reinterpret_cast<unsigned char *>(&one)[3] |= multiplierIsNeg;
      return -one;
    }
    cout << multiplier(-1,-1) << endl; // +1
    cout << multiplier( 1,-1) << endl; // -1
    cout << multiplier( 1, 1) << endl; // -1
    cout << multiplier(-1, 1) << endl; // -1
    

    【讨论】:

    • 我的问题是消除分支,基准测试结果不在这个问题中。
    • mixValue 是循环依赖变量,请参阅:mixValue = (mixValue + ...
    【解决方案7】:

    查看您的代码,您会发现您将始终将 mixValueinputLevel 的绝对值相加,除非两者都是正数。

    有了一些位摆弄和 IEEE 浮点知识,您可能会摆脱条件:

    // sets the first bit of f to zero => makes it positive.
    void absf( float& f ) {
       assert( sizeof( float ) == sizeof( int ) );
       reinterpret_cast<int&>( f ) &= ~0x80000000;
    }
    
    // returns a first-bit = 1 if f is positive
    int pos( float& f ) {
      return ~(reinterpret_cast<int&>(f) & 0x80000000) & 0x80000000;
    }
    
    // returns -fabs( f*g ) if f>0 and g>0, fabs(f*g) otherwise.    
    float prod( float& f, float& g ) {
      float p = f*g;
      float& rp=p;
      int& ri = reinterpret_cast<int&>(rp);
      absf(p);
      ri |= ( pos(f) & pos(g) & 0x80000000); // first bit = + & +
      return p;
    }
    
    int main(){
     struct T { float f, g, r; 
        void test() {
           float p = prod(f,g);
           float d = (p-r)/r;
           assert( -1e-15 < d && d < 1e-15 );
        }
     };
     T vals[] = { {1,1,-1},{1,-1,1},{-1,1,1},{-1,-1,1} };
     for( T* val=vals; val != vals+4; ++val ) {
        val->test();
     }
    }
    

    最后:你的循环

    for( ... ) {
        mixedResult += inputLevel + prod(mixedResult,inputLevel);
    }
    

    注意:您的累积尺寸不匹配。 inputLevel 是无量纲量,而 mixedResult 是您的……结果(例如,帕斯卡、伏特……)。您不能添加具有不同维度的两个数量。可能你想要mixedResult += prod( mixedResult, inputLevel ) 作为你的累加器。

    【讨论】:

      【解决方案8】:

      一些编译器(即 MSC)也需要手动检查符号。

      来源:

      volatile float mixValue;
      volatile float inputLevel;
      
      float u   = mixValue*inputLevel;
      float v   = -u;
      float a[] = { v, u };
      
      mixValue = (mixValue + inputLevel) + a[ (inputLevel<0.0) & (mixValue<0.0) ];
      

      英特尔C 11.1:

      movss     xmm1, DWORD PTR [12+esp]    
      mulss     xmm1, DWORD PTR [16+esp]    
      movss     xmm6, DWORD PTR [12+esp]    
      movss     xmm2, DWORD PTR [16+esp]    
      movss     xmm3, DWORD PTR [16+esp]    
      movss     xmm5, DWORD PTR [12+esp]    
      xorps     xmm4, xmm4                  
      movaps    xmm0, xmm4                  
      subss     xmm0, xmm1                  
      movss     DWORD PTR [esp], xmm0       
      movss     DWORD PTR [4+esp], xmm1     
      addss     xmm6, xmm2                  
      xor       eax, eax                    
      cmpltss   xmm3, xmm4                  
      movd      ecx, xmm3                   
      neg       ecx                         
      cmpltss   xmm5, xmm4                  
      movd      edx, xmm5                   
      neg       edx                         
      and       ecx, edx                    
      addss     xmm6, DWORD PTR [esp+ecx*4] 
      movss     DWORD PTR [12+esp], xmm6    
      

      gcc 4.5:

      flds    32(%esp)
      flds    16(%esp)
      fmulp   %st, %st(1)
      fld     %st(0)
      fchs
      fstps   (%esp)
      fstps   4(%esp)
      flds    32(%esp)
      flds    16(%esp)
      flds    16(%esp)
      flds    32(%esp)
      fxch    %st(2)
      faddp   %st, %st(3)
      fldz
      fcomi   %st(2), %st
      fstp    %st(2)
      fxch    %st(1)
      seta    %dl
      xorl    %eax, %eax
      fcomip  %st(1), %st
      fstp    %st(0)
      seta    %al
      andl    %edx, %eax
      fadds   (%esp,%eax,4)
      xorl    %eax, %eax
      fstps   32(%esp)
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2015-11-24
        • 2023-03-22
        • 2013-09-04
        • 2018-07-02
        • 2015-03-21
        • 2019-04-15
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多