【问题标题】:Fastest way to calculate a digit-sum for a large number (as a decimal string)计算大数数字和的最快方法(作为十进制字符串)
【发布时间】:2020-12-20 09:16:11
【问题描述】:

我使用 gmplib 获取大数并计算数值(数字总和:123 -> 674 -> 11 -> 2

这就是我所做的:

unsigned short getnumericvalue(const char *in_str)
{
    unsigned long number = 0;
    const char *ptr = in_str;
     
     do {
         if (*ptr != '9') number += (*ptr - '0'); // Exclude '9'
         ptr++;
     } while (*ptr != 0);
     
     unsigned short reduced = number % 9;
    
     return reduced == 0 ? 9 : reduced;
}

效果很好,但在 Xeon w-3235 上是否有更快的方法?

【问题讨论】:

  • 你怎么知道你需要让你的代码更快?您对其进行了剖析,发现这种转换是性能瓶颈?
  • 将 bignum 转换为字符串将比对数字求和慢得多,因此这种优化似乎毫无意义。
  • 请注意,如果您使用字符文字而不是硬编码 ASCII 值,则代码更易于阅读。
  • 为什么不直接计算N % 9(不先将您的数字转换为字符串)并将0替换为9,除非N==0?还是您的输入是字符串?
  • @Stephane 这有什么关系?在最坏的情况下,该部门将需要几个周期,但它只运行一次。条件移动运行循环的每次迭代。这并不快。

标签: c assembly sse intrinsics avx512


【解决方案1】:

您可以使用如下代码。算法的总体思路是:

  1. 按字节处理数据,直到达到高速缓存行对齐
  2. 一次读取一个缓存行,检查字符串的结尾,并将数字添加到 8 个累加器中
  3. 将 8 个累加器减为 1,并从头部添加计数
  4. 按字节处理余数

请注意,下面的代码尚未经过测试。

        // getnumericvalue(ptr)
        .section .text
        .type getnumericvalue, @function
        .globl getnumericvalue
getnumericvalue:
        xor %eax, %eax          // digit counter

        // process string until we reach cache-line alignment
        test $64-1, %dil        // is ptr aligned to 64 byte?
        jz 0f

1:      movzbl (%rdi), %edx     // load a byte from the string
        inc %rdi                // advance pointer
        test %edx, %edx         // is this the NUL byte?
        jz .Lend                // if yes, finish this function
        sub $'0', %edx          // turn ASCII character into digit
        add %edx, %eax          // and add to counter
        test $64-1, %dil        // is ptr aligned to 64 byte?
        jnz 1b                  // if not, process more data

        // process data in cache line increments until the end
        // of the string is found somewhere
0:      vpbroadcastd zero(%rip), %zmm1  // mask of '0' characters
        vpxor %xmm3, %xmm3, %xmm3       // vectorised digit counter

        vmovdqa32 (%rdi), %zmm0         // load one cache line from the string
        vptestmb %zmm0, %zmm0, %k0      // clear k0 bits if any byte is NUL
        kortestq %k0, %k0               // clear CF if a NUL byte is found
        jnc 0f                          // skip loop if a NUL byte is found

        .balign 16
1:      add $64, %rdi                   // advance pointer
        vpsadbw %zmm1, %zmm0, %zmm0     // sum groups of 8 bytes into 8 words
                                        // also subtracts '0' from each byte
        vpaddq %zmm3, %zmm0, %zmm3      // add to counters
        vmovdqa32 (%rdi), %zmm0         // load one cache line from the string
        vptestmb %zmm0, %zmm0, %k0      // clear k0 bits if any byte is NUL
        kortestq %k0, %k0               // clear CF if a NUL byte is found
        jc 1b                           // go on unless a NUL byte was found

        // reduce 8 vectorised counters into rdx
0:      vextracti64x4 $1, %zmm3, %ymm2  // extract high 4 words
        vpaddq %ymm2, %ymm3, %ymm3      // and add them to the low words
        vextracti128 $1, %ymm3, %xmm2   // extract high 2 words
        vpaddq %xmm2, %xmm3, %xmm3      // and add them to the low words
        vpshufd $0x4e, %xmm3, %xmm2     // swap qwords into xmm2
        vpaddq %xmm2, %xmm3, %xmm3      // and add to xmm0
        vmovq %xmm3, %rdx               // move digit counter back to rdx
        add %rdx, %rax                  // and add to counts from scalar head

        // process tail
1:      movzbl (%rdi), %edx     // load a byte from the string
        inc %rdi                // advance pointer
        test %edx, %edx         // is this the NUL byte?
        jz .Lend                // if yes, finish this function
        sub $'0', %edx          // turn ASCII character into digit
        add %rdx, %rax          // and add to counter
        jnz 1b                  // if not, process more data

.Lend:  xor %edx, %edx          // zero-extend RAX into RDX:RAX
        mov $9, %ecx            // divide by 9
        div %rcx                // perform division
        mov %edx, %eax          // move remainder to result register
        test %eax, %eax         // is the remainder zero?
        cmovz %ecx, %eax        // if yes, set remainder to 9
        vzeroupper              // restore SSE performance
        ret                     // and return
        .size getnumericvalue, .-getnumericvalue

        // constants
        .section .rodata
        .balign 4
zero:   .byte '0', '0', '0', '0'

【讨论】:

  • 评论不用于扩展讨论;这个对话是moved to chat
  • 正如@Stephane 在聊天中指出的那样,.align 16 在 Apple clang 中表示 .p2align 16,通过引入 大量 的 NOP 会降低性能。切勿使用.align,始终使用.balign.p2align 以避免这种歧义。
【解决方案2】:

这是一个可移植的解决方案:

  • 它天真地处理前几位数字,直到 ptr 正确对齐。
  • 然后它一次循环读取 8 个数字,并将这些数字成对相加到一个累加器中。在将 64 位累加器拆分为 number 之前,最多可以执行 28 次此类操作。
  • 终止测试验证包中的所有数字都具有等于3的高半字节。
  • 剩余的数字会一一处理。
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

unsigned getnumericvalue_simple(const char *in_str) {
    unsigned long number = 0;
    const char *ptr = in_str;

    do {
        if (*ptr != '9') number += (*ptr - '0'); // Exclude '9'
        ptr++;
    } while (*ptr != 0);

    return number <= 9 ? number : ((number - 1) % 9) + 1;
}

unsigned getnumericvalue_naive(const char *ptr) {
    unsigned long number = 0;

    while (*ptr) {
        number += *ptr++ - '0';
    }
    return number ? 1 + (number - 1) % 9 : 0;
}

unsigned getnumericvalue_parallel(const char *ptr) {
    unsigned long long number = 0;
    unsigned long long pack1, pack2;

    /* align source on ull boundary */
    while ((uintptr_t)ptr & (sizeof(unsigned long long) - 1)) {
        if (*ptr == '\0')
            return number ? 1 + (number - 1) % 9 : 0;
        number += *ptr++ - '0';
    }

    /* scan 8 bytes at a time */
    for (;;) {
        pack1 = 0;
#define REP8(x) x;x;x;x;x;x;x;x
#define REP28(x) x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x
        REP28(pack2 = *(const unsigned long long *)(const void *)ptr;
              pack2 -= 0x3030303030303030;
              if (pack2 & 0xf0f0f0f0f0f0f0f0)
                  break;
              ptr += sizeof(unsigned long long);
              pack1 += pack2);
        REP8(number += pack1 & 0xFF; pack1 >>= 8);
    }
    REP8(number += pack1 & 0xFF; pack1 >>= 8);

    /* finish trailing bytes */
    while (*ptr) {
        number += *ptr++ - '0';
    }
    return number ? 1 + (number - 1) % 9 : 0;
}

int main(int argc, char *argv[]) {
    clock_t start;
    unsigned naive_result, simple_result, parallel_result;
    double naive_time, simple_time, parallel_time;
    int digits = argc < 2 ? 1000000 : strtol(argv[1], NULL, 0);
    char *p = malloc(digits + 1);
    for (int i = 0; i < digits; i++)
        p[i] = "0123456789123456"[i & 15];
    p[digits] = '\0';

    start = clock();
    simple_result = getnumericvalue_simple(p);
    simple_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;

    start = clock();
    naive_result = getnumericvalue_naive(p);
    naive_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;

    start = clock();
    parallel_result = getnumericvalue_parallel(p);
    parallel_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;

    printf("simple:   %d digits -> %u, %7.3f msec\n", digits, simple_result, simple_time);
    printf("naive:    %d digits -> %u, %7.3f msec\n", digits, naive_result, naive_time);
    printf("parallel: %d digits -> %u, %7.3f msec\n", digits, parallel_result, parallel_time);

    return 0;
}

一亿位的时序:

simple:   100000000 digits -> 3, 100.380 msec
naive:    100000000 digits -> 3,  98.128 msec
parallel: 100000000 digits -> 3,   7.848 msec

请注意,发布版本中的额外测试不正确,因为getnumericvalue("9") 应该产生9,而不是0

并行版本比简单版本快 12 倍

通过编译器内部函数甚至汇编语言使用 AVX 指令可能会获得更高的性能,但对于非常大的数组,内存带宽似乎是限制因素。

【讨论】:

  • if (pack2 &lt; 0x3030303030303030) break 不依赖于终止 0 之后的字节也是 0 吗?如果您读/写一些动态分配的内存的唯一其他方法是通过char*,那么它可能不是严格混叠UB,但它仍然值得至少评论来进行这种指针转换。如果在 char[] array 上使用它是 UB; may-alias 仅适用于通过char* 访问任何对象,不适用于char objects 使用不同的指针类型访问。
  • 此外,在'9' 上分支的“简单”方式看起来比在数字有规则模式时更好。 (每 16 个元素一个 '9')。如果编译器没有生成无分支代码,现代英特尔分支预测器可能会有效地处理这一问题。
  • @PeterCordes:抱歉,字符串终止测试是假的。我修好了它。事实上,我很惊讶在'9' 的测试中看到如此小的退化。我应该看看大会。
  • @chqrlie "getnumericvalue("9") 应该产生 9,而不是 0。"是的,我修好了。
  • 忘记添加较早的回复:内存带宽限制:现代 x86 台式机/笔记本电脑通常可以在每个核心时钟周期从 RAM 维持 8 个字节(或更多,最近没有做数学计算)。您的展开循环可能会接近这个值,每个 qword 4 uops 不计算指针增量(感谢展开)。如果数据在 L2 甚至 L3 缓存中很热(例如,来自具有合理块大小的 read() 系统调用从页面缓存复制,如果您不使用 mmap),AVX2 或 AVX512 可以明显更快。
猜你喜欢
  • 2014-05-13
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2013-11-24
  • 2017-09-06
  • 2019-06-17
相关资源
最近更新 更多