【问题标题】:Fast 2D array lookup of int16_t LUT using AVX2 or AVX512使用 AVX2 或 AVX512 快速查找 int16_t LUT 的二维数组
【发布时间】:2021-11-01 12:47:05
【问题描述】:

我想加快在二维数组中执行一系列查找的算法。基本上它就像矩阵乘法以相同的顺序访问两个数组,但没有乘法,只有查找(可能 B 矩阵这样存储更有效?)。所以查找表是 256x256,具有 int16_t 值,A、B 矩阵具有 int8_t 值。代码如下:

for(int i = 0; i < M; ++i) {
    for(int j = 0; j < N; ++j) {
        int temp = 0;
        for(int k = 0; k < K; ++k) {
            uint8_t a = A[i*K+k];
            uint8_t b = B[j*K+k];
            temp += lut[a][b];
        }
        C[i*N+j] = temp; 
    }
}

我知道 AVX 可以进行并行查找,但我找不到在像 256x256 这样的大型 2D 矩阵中执行此操作的方法。支持 AVX512。此外,使用了 g++,也欢迎任何其他优化。

提前致谢

【问题讨论】:

    标签: c++ optimization avx lookup-tables avx512


    【解决方案1】:

    AVX2 只有 32 位和 64 位的 vgather,没有任何其他类型。假设您可以通过将 LUT 更改为 int 来使其工作,然后使用编译器内在 _mm256_i32gather_epi32

    我想您也可以保留当前尺寸。但是,您必须手动将查找转换为适当的格式。如果没有一些广泛的 ASM 或使用内在函数,这将是不可能的。我想这样的事情会起作用:

    1. 加载 A 和 B 的 8 个字节
    2. 使用 VPUNPCKLBW 将字节交错到 uint16
    3. 使用 VPMOVZXBD 将零扩展到 8 个整数
    4. 在 LUT 上使用 VPGATHERDD。它有一个明确的比例参数。如果将此设置为 2,则可以使用用于 int 的未对齐访问来收集 int16 值。使用 -2 的固定偏移量。然后每个条目的低 2 个字节将是垃圾。确保在 LUT 之前至少 2 个字节可以无段错误地访问(读取:分配 [256*256+1],将实际 LUT 存储在偏移量 1)
    5. 使用 VPSRAD 移出 2 个垃圾字节并符号扩展为 int32
    6. 做你的矢量化总和
    7. 在内循环结束时,使用水平加法将向量减少为标量

    是的,我明白为什么编译器无法做到这一点。另请注意,收集通常不会比多个标量访问快得多。但是,它在 GPU(甚至是 Intel 内置)上会很好地工作,因为它们可以使用纹理单元来执行此操作。

    在我在英特尔 Coffee Lake 上进行的快速基准测试中,这种实现似乎是正确的,并且大约是原始大 K 的两倍。不过,我不会完全称它为可读的。

    #include <cstdint>
    // using std::uint8_t, std::int16_t
    #include <immintrin.h>
    // using intrinsics up to AVX2
    
    void lutmat(int M, int N, int K,
                const std::uint8_t* A,
                const std::uint8_t* B,
                int* C,
                const std::int16_t* lut)
    {
        // this needs to be valid! lut[-1] must not segfault!
        const int* lut32 = reinterpret_cast<const int*>(lut - 1);
        for(std::ptrdiff_t i = 0; i < M; ++i) {
            const std::uint8_t* Ai = A + i * K;
            for(std::ptrdiff_t j = 0; j < N; ++j) {
                const std::uint8_t* Bj = B + j * K;
                __m256i sum8 = _mm256_set1_epi32(0);
                std::ptrdiff_t k;
                for(k = 0; K - k >= 16; k += 16) {
                    // fetch 16 bytes per input matrix
                    __m128i ak = _mm_loadu_si128((const __m128i_u*) (Ai + k));
                    __m128i bk = _mm_loadu_si128((const __m128i_u*) (Bj + k));
                    // interleave into 2 x 8 16 bit values. These are our indices
                    __m128i interleaved_lo = _mm_unpacklo_epi8(bk, ak);
                    __m128i interleaved_hi = _mm_unpackhi_epi8(bk, ak);
                    // zero extend indices to 32 bit values
                    __m256i extended_lo = _mm256_cvtepu16_epi32(interleaved_lo);
                    __m256i extended_hi = _mm256_cvtepu16_epi32(interleaved_hi);
                    // do unaligned gather of 32 bit values.
                    // Valid bytes are in upper 2 byte per int due to the offset in lut32
                    __m256i gathered_lo = _mm256_i32gather_epi32(lut32, extended_lo, 2 /*scale*/);
                    __m256i gathered_hi = _mm256_i32gather_epi32(lut32, extended_hi, 2 /*scale*/);
                    // sign-extend and remove garbage in lower 2 byte
                    __m256i corrected_lo = _mm256_srai_epi32(gathered_lo, 16);
                    __m256i corrected_hi = _mm256_srai_epi32(gathered_hi, 16);
                    // add to 8 partial sums
                    sum8 = _mm256_add_epi32(sum8, corrected_lo);
                    sum8 = _mm256_add_epi32(sum8, corrected_hi);
                }
                if(K - k >= 8) {
                    // single iteration using just 8 fetched values
                    __m128i ak = _mm_loadl_epi64((const __m128i*) (Ai + k));
                    __m128i bk = _mm_loadl_epi64((const __m128i*) (Bj + k));
                    __m128i interleaved_lo = _mm_unpacklo_epi8(bk, ak);
                    __m256i extended_lo = _mm256_cvtepu16_epi32(interleaved_lo);
                    __m256i gathered_lo = _mm256_i32gather_epi32(lut32, extended_lo, 2);
                    __m256i corrected_lo = _mm256_srai_epi32(gathered_lo, 16);
                    sum8 = _mm256_add_epi32(sum8, corrected_lo);
                    k += 8;
                }
                // reduce 8 to 4 partial sums
                __m128i low4 = _mm256_castsi256_si128(sum8);
                __m128i high4  = _mm256_extracti128_si256(sum8, 1);
                __m128i sum4 = _mm_add_epi32(low4, high4);
                if(K - k >= 4) {
                    // single iteration using 4 fetched values
                    __m128i ak = _mm_cvtsi32_si128(*(const int*) (Ai + k));
                    __m128i bk = _mm_cvtsi32_si128(*(const int*) (Bj + k));
                    __m128i interleaved = _mm_unpacklo_epi8(bk, ak);
                    __m128i extended = _mm_cvtepu16_epi32(interleaved);
                    __m128i gathered = _mm_i32gather_epi32(lut32, extended, 2);
                    __m128i corrected = _mm_srai_epi32(gathered, 16);
                    sum4 = _mm_add_epi32(sum4, corrected);
                    k += 4;
                }
                // reduce partial sums to 1 scalar sum
                __m128i high2  = _mm_unpackhi_epi64(sum4, sum4);
                __m128i sum2 = _mm_add_epi32(high2, sum4);
                __m128i high1  = _mm_shuffle_epi32(sum2, _MM_SHUFFLE(3,3,1,1));
                int sum = _mm_cvtsi128_si32(_mm_add_epi32(sum2, high1));
                // add the last few entries
                // we use a separate partial sum to avoid a dependency chain through
                // the reduction above
                int tail = 0;
                for(; k < K; ++k) {
                    uint8_t a = Ai[k];
                    uint8_t b = Bj[k];
                    tail += lut[a * 256 + b];
                }
                C[i*N+j] = sum + tail;
            }
        }
    }
    

    【讨论】:

    • AVX2 已经收集了 32 位和 64 位整数 (vpgatherXX) 以及浮点数和双精度数的负载。
    • @PaulR 谢谢!我改变了文字
    • 与其尝试以 4 种不同的方式对广泛的负载进行洗牌,最好进行 128 位或 64 位负载以将 vpmovzxbd 零扩展名馈送到 __m256i__m512i 向量中.但是,是的,您的步骤看起来很合理。如果 LUT 已知为非负数,则零扩展比将 16 位符号扩展为 32 位元素更便宜。对于班次,您不需要可变班次,只需简单的旧 vpslldvpsrad 立即 16。
    • 哦,是的,对。 IIRC 之前有一个关于 16 位收集的 SO 问题,我记得之前出现过 -2 单班次的想法。就避免缓存行拆分而言,相对于 64 字节行,跨入索引 0 的前一个缓存行与跨入索引 31 的下一行一样糟糕,我猜。除了在 LUT 的末尾比在开头更容易填充,以避免在之前而不是之后读取到未映射的页面。 (无论哪种方式,它都会使其触及额外的页面,超出 256^2 x 2B = 128kiB。)
    • 很好的编辑。可以展开循环以使用unpacklo_epi8unpackhi_epi8 来使用128 位负载的两半。 (我之前没有想到这一点)。在 hsum 中,我建议使用 extracti128 而不是 f,因为这是整数数据,而且我们已经在使用 AVX2 内在函数。在大多数 CPU 上无关紧要。此外,_mm_shufflelo_epi16 在 AVX2 CPU 上可能很傻,也许最好只使用简单的 32 位元素 shuffle,如 _mm_shuffle_epi32(v, _MM_SHUFFLE(3,3,1,1)),它在 AVX2 CPU 上运行相同,并且更具人类可读性,并使 clang 的窥视孔优化为 @987654336 @ 更清楚。
    猜你喜欢
    • 2014-01-16
    • 1970-01-01
    • 2021-01-12
    • 2016-06-17
    • 2020-05-23
    • 1970-01-01
    • 2014-10-27
    • 1970-01-01
    相关资源
    最近更新 更多