【发布时间】:2021-04-23 09:12:58
【问题描述】:
我有以下结构,它存储键和通用用户指定的值:
typedef struct {
uint32_t len;
uint32_t cap;
int32_t *keys;
void *vals;
} dict;
现在我想创建一个函数,它遍历 keys 并返回相应的 value。
非 SIMD 版本:
void*
dict_find(dict *d, int32_t k, size_t s) {
size_t i;
i = 0;
while (i < d->len) {
if (d->keys[i] == k) {
void *p;
p = (uint8_t*)d->vals + i * s;
return p;
}
++i;
}
return NULL;
}
我尝试对上面的 sn-p 进行矢量化处理,结果如下:
void*
dict_find_simd(dict *d, int32_t k, size_t s) {
__m256i ymm0;
ymm0 = _mm256_broadcastd_epi32(*(__m128i*)&k);
__m256i ymm1;
uint32_t i;
int m;
uint8_t b;
i = 0;
while (i < d->len) { // [d->len] is aligned in 32 byte box.
ymm1 = _mm256_load_si256((__m256i*)(d->keys + i));
ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
m = _mm256_movemask_epi8(ymm1);
b = __builtin_ctz(m) >> 2;
i += (8 + b * d->len); // Artificially break the loop.
// Remember [i] stores the modified value.
}
if (i <= d->len)
return NULL;
i -= (8 + b * d->len); // Restore the modified value.
i += b;
void *p;
p = (uint8_t*)d->vals + i * s;
return p;
}
该功能似乎工作正常(没有进行太多测试)?
但是,有两个问题:
- 注意:我正在检查
i > d->len是否返回指针。i可以溢出,它会在那里返回NULL。我该如何解决这个问题? - 您可能注意到我使用
_mm256_movemask_epi8和__builtin_ctz的组合来获取找到的键的索引。有没有更好的方法(可能是一条获得非零值位置的指令)来做到这一点(没有 AVX512)?
【问题讨论】:
-
为什么不简单地
for(i=0; i<d-len; i+=8) { /* SIMD stuff */ if(b) break;}? -
我认为使用 SIMD 进行分支是“不可行的”。也许我错了。
-
其实你也可以用
if(m) break;break,在循环之后做b的计算。如果您要处理许多非常小的数组(通常适合几个寄存器),则分支将是低效的。在这种情况下,会有更有效的解决方案。 (这个问题肯定有重复,但我暂时不想找) -
人为地修改循环计数器通常对性能来说是一个更大的问题,因为它破坏了循环展开、预加载数据等的许多可能性。如果你只是引入一个新的会稍微好一点变量
break_next并将其放入循环条件中。不过你也可以直接break。 -
如果
m是0,__builtin_ctz(m)是未定义的行为。使用_tzcnt_u32(m)是您想要明确定义的行为。大多数(全部?)AVX2 CPU 都有 tzcnt。在实践中,__bultin_ctz(m)将对大多数编译器使用 tzcnt,因此无论如何都会产生明确定义的结果。但是,是的,同意@chtz,循环计数器更新看起来很糟糕,我认为在上一个比较和下一个迭代的加载地址之间创建一个依赖链,所以你的瓶颈是 latency 而不是吞吐量。跨度>