AVX2 只有 32 位和 64 位的 vgather,没有任何其他类型。假设您可以通过将 LUT 更改为 int 来使其工作,然后使用编译器内在 _mm256_i32gather_epi32
我想您也可以保留当前尺寸。但是,您必须手动将查找转换为适当的格式。如果没有一些广泛的 ASM 或使用内在函数,这将是不可能的。我想这样的事情会起作用:
- 加载 A 和 B 的 8 个字节
- 使用 VPUNPCKLBW 将字节交错到 uint16
- 使用 VPMOVZXBD 将零扩展到 8 个整数
- 在 LUT 上使用 VPGATHERDD。它有一个明确的比例参数。如果将此设置为 2,则可以使用用于 int 的未对齐访问来收集 int16 值。使用 -2 的固定偏移量。然后每个条目的低 2 个字节将是垃圾。确保在 LUT 之前至少 2 个字节可以无段错误地访问(读取:分配 [256*256+1],将实际 LUT 存储在偏移量 1)
- 使用 VPSRAD 移出 2 个垃圾字节并符号扩展为 int32
- 做你的矢量化总和
- 在内循环结束时,使用水平加法将向量减少为标量
是的,我明白为什么编译器无法做到这一点。另请注意,收集通常不会比多个标量访问快得多。但是,它在 GPU(甚至是 Intel 内置)上会很好地工作,因为它们可以使用纹理单元来执行此操作。
在我在英特尔 Coffee Lake 上进行的快速基准测试中,这种实现似乎是正确的,并且大约是原始大 K 的两倍。不过,我不会完全称它为可读的。
#include <cstdint>
// using std::uint8_t, std::int16_t
#include <immintrin.h>
// using intrinsics up to AVX2
void lutmat(int M, int N, int K,
const std::uint8_t* A,
const std::uint8_t* B,
int* C,
const std::int16_t* lut)
{
// this needs to be valid! lut[-1] must not segfault!
const int* lut32 = reinterpret_cast<const int*>(lut - 1);
for(std::ptrdiff_t i = 0; i < M; ++i) {
const std::uint8_t* Ai = A + i * K;
for(std::ptrdiff_t j = 0; j < N; ++j) {
const std::uint8_t* Bj = B + j * K;
__m256i sum8 = _mm256_set1_epi32(0);
std::ptrdiff_t k;
for(k = 0; K - k >= 16; k += 16) {
// fetch 16 bytes per input matrix
__m128i ak = _mm_loadu_si128((const __m128i_u*) (Ai + k));
__m128i bk = _mm_loadu_si128((const __m128i_u*) (Bj + k));
// interleave into 2 x 8 16 bit values. These are our indices
__m128i interleaved_lo = _mm_unpacklo_epi8(bk, ak);
__m128i interleaved_hi = _mm_unpackhi_epi8(bk, ak);
// zero extend indices to 32 bit values
__m256i extended_lo = _mm256_cvtepu16_epi32(interleaved_lo);
__m256i extended_hi = _mm256_cvtepu16_epi32(interleaved_hi);
// do unaligned gather of 32 bit values.
// Valid bytes are in upper 2 byte per int due to the offset in lut32
__m256i gathered_lo = _mm256_i32gather_epi32(lut32, extended_lo, 2 /*scale*/);
__m256i gathered_hi = _mm256_i32gather_epi32(lut32, extended_hi, 2 /*scale*/);
// sign-extend and remove garbage in lower 2 byte
__m256i corrected_lo = _mm256_srai_epi32(gathered_lo, 16);
__m256i corrected_hi = _mm256_srai_epi32(gathered_hi, 16);
// add to 8 partial sums
sum8 = _mm256_add_epi32(sum8, corrected_lo);
sum8 = _mm256_add_epi32(sum8, corrected_hi);
}
if(K - k >= 8) {
// single iteration using just 8 fetched values
__m128i ak = _mm_loadl_epi64((const __m128i*) (Ai + k));
__m128i bk = _mm_loadl_epi64((const __m128i*) (Bj + k));
__m128i interleaved_lo = _mm_unpacklo_epi8(bk, ak);
__m256i extended_lo = _mm256_cvtepu16_epi32(interleaved_lo);
__m256i gathered_lo = _mm256_i32gather_epi32(lut32, extended_lo, 2);
__m256i corrected_lo = _mm256_srai_epi32(gathered_lo, 16);
sum8 = _mm256_add_epi32(sum8, corrected_lo);
k += 8;
}
// reduce 8 to 4 partial sums
__m128i low4 = _mm256_castsi256_si128(sum8);
__m128i high4 = _mm256_extracti128_si256(sum8, 1);
__m128i sum4 = _mm_add_epi32(low4, high4);
if(K - k >= 4) {
// single iteration using 4 fetched values
__m128i ak = _mm_cvtsi32_si128(*(const int*) (Ai + k));
__m128i bk = _mm_cvtsi32_si128(*(const int*) (Bj + k));
__m128i interleaved = _mm_unpacklo_epi8(bk, ak);
__m128i extended = _mm_cvtepu16_epi32(interleaved);
__m128i gathered = _mm_i32gather_epi32(lut32, extended, 2);
__m128i corrected = _mm_srai_epi32(gathered, 16);
sum4 = _mm_add_epi32(sum4, corrected);
k += 4;
}
// reduce partial sums to 1 scalar sum
__m128i high2 = _mm_unpackhi_epi64(sum4, sum4);
__m128i sum2 = _mm_add_epi32(high2, sum4);
__m128i high1 = _mm_shuffle_epi32(sum2, _MM_SHUFFLE(3,3,1,1));
int sum = _mm_cvtsi128_si32(_mm_add_epi32(sum2, high1));
// add the last few entries
// we use a separate partial sum to avoid a dependency chain through
// the reduction above
int tail = 0;
for(; k < K; ++k) {
uint8_t a = Ai[k];
uint8_t b = Bj[k];
tail += lut[a * 256 + b];
}
C[i*N+j] = sum + tail;
}
}
}