【问题标题】:Fast, branchless unsigned int absolute difference快速、无分支的无符号整数绝对差
【发布时间】:2014-04-22 02:33:25
【问题描述】:

我有一个程序花费大部分时间计算 RGB 值之间的欧几里得距离(无符号 8 位 Word8 的 3 元组)。我需要一个快速、无分支的 unsigned int 绝对差分函数,这样

unsigned_difference :: Word8 -> Word8 -> Word8
unsigned_difference a b = max a b - min a b

特别是

unsigned_difference a b == unsigned_difference b a

我使用 GHC 7.8 中的新 primops 提出了以下建议:

-- (a < b) * (b - a) + (a > b) * (a - b)
unsigned_difference (I# a) (I# b) =
    I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))]

ghc -O2 -S 编译成

.Lc42U:
    movq 7(%rbx),%rax
    movq $ghczmprim_GHCziTypes_Izh_con_info,-8(%r12)
    movq 8(%rbp),%rbx
    movq %rbx,%rcx
    subq %rax,%rcx
    cmpq %rax,%rbx
    setg %dl
    movzbl %dl,%edx
    imulq %rcx,%rdx
    movq %rax,%rcx
    subq %rbx,%rcx
    cmpq %rax,%rbx
    setl %al
    movzbl %al,%eax
    imulq %rcx,%rax
    addq %rdx,%rax
    movq %rax,(%r12)
    leaq -7(%r12),%rbx
    addq $16,%rbp
    jmp *(%rbp)

使用ghc -O2 -fllvm -optlo -O3 -S 编译会生成以下asm:

.LBB6_1:
    movq    7(%rbx), %rsi
    movq    $ghczmprim_GHCziTypes_Izh_con_info, 8(%rax)
    movq    8(%rbp), %rcx
    movq    %rsi, %rdx
    subq    %rcx, %rdx
    xorl    %edi, %edi
    subq    %rsi, %rcx
    cmovleq %rdi, %rcx
    cmovgeq %rdi, %rdx
    addq    %rcx, %rdx
    movq    %rdx, 16(%rax)
    movq    16(%rbp), %rax
    addq    $16, %rbp
    leaq    -7(%r12), %rbx
    jmpq    *%rax  # TAILCALL

所以 LLVM 设法用(更有效?)条件移动指令替换比较。不幸的是,使用-fllvm 编译对我的程序运行时影响不大。

但是,这个函数有两个问题。

  • 我想比较Word8,但是比较primops需要使用Int。这会导致不必要的分配,因为我不得不存储 64 位 Int 而不是 Word8

我已分析并确认,fromIntegral :: Word8 -&gt; Int 的使用占该程序总分配的 42.4%。

  • 我的版本使用 2 次比较、2 次乘法和 2 次减法。我想知道是否有更有效的方法,使用按位运算或 SIMD 指令并利用我正在比较 Word8 的事实。

我之前已将问题标记为C/C++,以吸引那些更倾向于位操作的人的注意。我的问题使用 Haskell,但我会接受以任何语言实现正确方法的答案。

结论:

我决定使用

w8_sad :: Word8 -> Word8 -> Int16
w8_sad a b = xor (diff + mask) mask
    where diff = fromIntegral a - fromIntegral b
          mask = unsafeShiftR diff 15

因为它比我原来的 unsigned_difference 函数更快,并且易于实现。 Haskell 中的 SIMD 内部函数尚未成熟。因此,虽然 SIMD 版本更快,但我决定使用标量版本。

【问题讨论】:

  • (a - b) &amp; 127 工作吗?
  • @cdk 我怀疑最好的答案是上一级。解释为什么代码需要 RGB 之间的欧几里得距离以及如何使用该值。也许欧几里得距离的平方就足够了。 (a - b)*(a - b)
  • @cdk:我猜 chux 的意思是加法和乘法形成一个以 2^8 为模的环,所以(a-b)*(a-b) = (b-a)*(b-a)
  • 为什么需要abs 函数来计算欧几里得距离?你只需要abs 就可以达到最高标准。
  • x86 SSE2 有一个 psadbw 指令,它为您提供 8 个 Word8 SAD 操作的总和。因此,如果您将 2 个输入字节零扩展为 XMM 寄存器,psadbw 会满足您的需求。它专为对多像素块进行运动搜索的视频编解码器而设计,但对于您的用例,您可以 SSE4.1 pmovzxbq 将 2 个字节加载到 2 个 qwords 中以并行检查 2 个像素分量。也pmuludq 对 2 个结果求平方。我不知道如何让 Haskell 编译器发出它;我根本不知道 Haskell。

标签: performance haskell bit-manipulation simd


【解决方案1】:

编辑:更改我的答案,我为此错误配置了优化。

我在 C 中为此设置了一个快速测试平台,我发现

a - b + (a &lt; b) * ((b - a) &lt;&lt; 1);

头发更好,至少在我的设置中。我的方法的优点是消除了比较。当不需要时,您的版本会隐式处理 a - b == 0,就像它是一个单独的案例一样。

我和你的测试需要

  • 您的实现:371 毫秒
  • 此实现:324 毫秒
  • 加速:14%

我尝试了一种非分支绝对值的方法,结果更好。请注意,编译器是否认为输入或输出已签名是无关紧要的。它围绕大的无符号值循环,但由于它只需要处理小值(如问题所述),它应该就足够了。

s32 diff = a - b;
u32 mask = diff >> 31;
return (diff + mask) ^ mask;
  • 您的实施:371 毫秒
  • 这个实现:241ms
  • 加速:53%

【讨论】:

  • 你是对的,我复制的示例假设 >> 32 具有某种行为,我根据我的 CPU 的功能对其进行了简化,这不是正确的方法。我想出了一个修复方法并更新了我的答案。
【解决方案2】:

好吧,我尝试了一些基准测试。我使用Criterion 进行基准测试,因为它进行了适当的显着性测试。我这里也使用QuickCheck 来确保所有方法返回相同的结果。

我用 GHC 7.6.3 编译(所以很遗憾,我不能包含你的 primops 函数)和-O3

ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff

主要我们可以看到幼稚的实现和一些摆弄之间的区别:

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 a b = max a b - min a b

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
  where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        mask = unsafeShiftR v 63

输出:

benchmarking absdiff_Word8/1
mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
....

benchmarking absdiff_Word8/2
mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
...

我使用“Bit Twiddling Hacks here”中的absolute integer value 技巧。不幸的是,我们需要强制转换,我认为单独在Word8 的域中无法很好地解决问题,但无论如何使用本机整数类型似乎是明智的(虽然绝对不需要创建堆对象)。

看起来差别不大,但我的测试设置也不完美:我将函数映射到大量随机值以排除分支预测,从而使分支版本看起来比实际更有效.这会导致 thunk 在内存中累积,这可能会对时间产生很大影响。当我们减去维护列表的恒定开销时,我们很可能会看到比 20% 的加速要多得多。

生成的程序集其实还不错(这是函数的内联版本):

.Lc4BB:
    leaq 7(%rbx),%rax
    movq 8(%rbp),%rbx
    subq (%rax),%rbx
    movq %rbx,%rax
    sarq $63,%rax
    movq $base_GHCziInt_I64zh_con_info,-8(%r12)
    addq %rax,%rbx
    xorq %rax,%rbx
    movq %rbx,0(%r12)
    leaq -7(%r12),%rbx
    movq $s4z0_info,8(%rbp)

1 次减法,1 次加法,1 次右移,1 次异或,没有分支,正如预期的那样。使用 LLVM 后端不会显着改善运行时。

如果你想尝试更多的东西,希望这对你有用。

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Data.Word
import Data.Int
import Data.Bits
import Control.Arrow ((***))
import Control.DeepSeq (force)
import Control.Exception (evaluate)
import Control.Monad
import System.Random
import Unsafe.Coerce

import Test.QuickCheck hiding ((.&.))
import Criterion.Main

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 !a !b = max a b - min a b

absdiff1_int16 :: Int16 -> Int16 -> Int16
absdiff1_int16 a b = max a b - min a b

absdiff1_int :: Int -> Int -> Int
absdiff1_int a b = max a b - min a b

absdiff2_int16 :: Int16 -> Int16 -> Int16
absdiff2_int16 a b = xor (v + mask) mask
  where v = a - b
        mask = unsafeShiftR v 15

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 !a !b = unsafeCoerce $ xor (v + mask) mask
  where !v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        !mask = unsafeShiftR v 63

absdiff3_w8 :: Word8 -> Word8 -> Word8
absdiff3_w8 a b = if a > b then a - b else b - a

{-absdiff4_int :: Int -> Int -> Int-}
{-absdiff4_int (I# a) (I# b) =-}
    {-I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}

e2e :: (Enum a, Enum b) => a -> b
e2e = toEnum . fromEnum

prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y
prop_same2 (x::Word8) (y::Word8) = absdiff1_int16 x' y' == absdiff2_int16 x' y'
    where x' = e2e x
          y' = e2e y

check = quickCheck prop_same1
     >> quickCheck prop_same2

instance (Random x, Random y) => Random (x, y) where
  random gen1 =
    let (x, gen2) = random gen1
        (y, gen3) = random gen2
    in ((x,y),gen3)

main =
    do check
       !pairs_w8 <- fmap force $ replicateM 10000 (randomIO :: IO (Word8,Word8))
       let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
       defaultMain
         [ bgroup "absdiff_Word8" [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
                                  , bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
                                  , bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
                                  ]
         , bgroup "absdiff_Int16" [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
                                  , bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
                                  ]
         {-, bgroup "absdiff_Int"   [ bench "1" $ whnf (absdiff1_int 13) 14-}
                                  {-, bench "2" $ whnf (absdiff3_int 13) 14-}
                                  {-]-}
         ]

【讨论】:

    【解决方案3】:

    如果您的目标是具有 SSE 指令的系统,您可以使用它来提高性能。我针对其他发布的方法对此进行了测试,这似乎是最快的方法。

    区分大量值的示例结果:

    diff0: 188.020679 ms // branching
    diff1: 118.934970 ms // max min
    diff2: 97.087710 ms  // branchless mul add
    diff3: 54.495269 ms  // branchless signed
    diff4: 31.159628 ms  // sse
    diff5: 30.855885 ms  // sse v2
    

    下面是我的完整测试代码。我通过 SSE 内在函数使用了 SSE2 指令,这些指令现在在 x86ish CPU 中广泛可用,它应该是非常可移植的(MSVC、GCC、Clang、Intel 编译器等)。

    注意事项:

    • 实际上,这会先计算最大值,然后再计算最小值,然后再减去,但每条指令一次会计算 16 个值。
    • diff5 中展开它似乎效果不大,但可能可以调整。
    • 最后 15 个或更少值的回退当前在循环中使用带符号的技巧方法,但可能会通过展开和/或 SSE 进一步加快速度。
    • 函数本身非常简单,因此它们应该可以轻松移植到任何具有 SSE 内部函数或 asm 的东西。
    • 我使用了 Windows 特定的计时函数,因为 std::chrono::high_resolution_clock 在 MSVC 实现中的精度很低,对此感到抱歉,并且对于 C/C++ 测试代码的肮脏混合。
    • 对性能进行计时后,将针对参考分支实现对结果进行测试,因此它们应该是正确的。

    如果您对代码或此方法有任何疑问/建议,请发表评论。

    #include <cstdlib>
    #include <cstdint>
    #include <cstdio>
    #include <cmath>
    #include <random>
    #include <algorithm>
    
    #define WIN32_LEAN_AND_MEAN
    #define NOMINMAX
    #include <Windows.h>
    
    #include <emmintrin.h> // sse2
    
    // branching
    void diff0(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
        std::size_t n)
    {
        for (std::size_t i = 0; i < n; i++) {
            res[i] = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
        }
    }
    
    // max min
    void diff1(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
        std::size_t n)
    {
        for (std::size_t i = 0; i < n; i++) {
            res[i] = std::max(a[i], b[i]) - std::min(a[i], b[i]);
        }
    }
    
    // branchless mul add
    void diff2(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
        std::size_t n)
    {
        for (std::size_t i = 0; i < n; i++) {
            res[i] = (a[i] > b[i]) * (a[i] - b[i]) + (a[i] < b[i]) * (b[i] - a[i]);
        }
    }
    
    // branchless signed
    void diff3(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
        std::size_t n)
    {
        for (std::size_t i = 0; i < n; i++) {
            std::int16_t  diff = a[i] - b[i];
            std::uint16_t mask = diff >> 15;
            res[i] = (diff + mask) ^ mask;
        }
    }
    
    // sse
    void diff4(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
        std::size_t n)
    {
        auto pA = reinterpret_cast<const __m128i*>(a);
        auto pB = reinterpret_cast<const __m128i*>(b);
        auto pRes = reinterpret_cast<__m128i*>(res);
        std::size_t i = 0;
        for (std::size_t j = n / 16; j--; i++) {
            __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
            __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
            _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
        }
        for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
            std::int16_t  diff = a[i] - b[i];
            std::uint16_t mask = diff >> 15;
            res[i] = (diff + mask) ^ mask;
        }
    }
    
    // sse v2
    void diff5(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
        std::size_t n)
    {
        auto pA = reinterpret_cast<const __m128i*>(a);
        auto pB = reinterpret_cast<const __m128i*>(b);
        auto pRes = reinterpret_cast<__m128i*>(res);
        std::size_t i = 0;
        const std::size_t UNROLL = 2;
        for (std::size_t j = n / (16 * UNROLL); j--; i += UNROLL) {
            __m128i max0 = _mm_max_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
            __m128i min0 = _mm_min_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
            __m128i max1 = _mm_max_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
            __m128i min1 = _mm_min_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
            _mm_store_si128(pRes + i + 0, _mm_sub_epi8(max0, min0));
            _mm_store_si128(pRes + i + 1, _mm_sub_epi8(max1, min1));
        }
        for (std::size_t j = n % (16 * UNROLL) / 16; j--; i++) {
            __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
            __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
            _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
        }
        for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
            std::int16_t  diff = a[i] - b[i];
            std::uint16_t mask = diff >> 15;
            res[i] = (diff + mask) ^ mask;
        }
    }
    
    int main() {
        const std::size_t ALIGN = 16; // sse requires 16 bit align
        const std::size_t N = 10 * 1024 * 1024 * 3;
    
        auto a = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
        auto b = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
    
        { // fill with random values
            std::mt19937 engine(std::random_device{}());
            std::uniform_int<std::uint8_t> distribution(0, 255);
            for (std::size_t i = 0; i < N; i++) {
                a[i] = distribution(engine);
                b[i] = distribution(engine);
            }
        }
    
        auto res0 = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff0 results
        auto resX = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff1+ results
    
        LARGE_INTEGER f, t0, t1;
        QueryPerformanceFrequency(&f);
    
        QueryPerformanceCounter(&t0);
        diff0(a, b, res0, N);
        QueryPerformanceCounter(&t1);
        printf("diff0: %.6f ms\n",
            static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);
    
    #define TEST(diffX)\
        QueryPerformanceCounter(&t0);\
        diffX(a, b, resX, N);\
        QueryPerformanceCounter(&t1);\
        printf("%s: %.6f ms\n", #diffX,\
            static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);\
        for (std::size_t i = 0; i < N; i++) {\
            if (resX[i] != res0[i]) {\
                printf("error: %s(%03u, %03u) == %03u != %03u\n", #diffX,\
                    a[i], b[i], resX[i], res0[i]);\
                break;\
            }\
        }
    
        TEST(diff1);
        TEST(diff2);
        TEST(diff3);
        TEST(diff4);
        TEST(diff5);
    
        _mm_free(a);
        _mm_free(b);
        _mm_free(res0);
        _mm_free(resX);
    
        getc(stdin);
        return 0;
    }
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2013-01-12
      • 2015-11-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多