【发布时间】:2022-01-17 06:49:11
【问题描述】:
我已尝试加快玩具 GEMM 的实施。我处理需要优化的 MM 内核的 32x32 双精度块。我可以访问 AVX2 和 FMA。
我在下面定义了两个代码(在 ASM 中,我为格式的粗糙表示歉意),一个是使用 AVX2 功能,另一个是使用 FMA。
在不涉及微观基准的情况下,我想尝试了解(理论上)为什么 AVX2 实现比 FMA 版本快 1.11 倍。以及如何改进这两个版本。
下面的代码适用于 3000x3000 MM 的双精度数,内核是使用经典的、朴素的 MM 实现的,具有可互换的最深循环。我正在使用 Ryzen 3700x/Zen 2 作为开发 CPU。
我没有尝试过积极展开,因为担心 CPU 可能会耗尽物理寄存器。
AVX2 32x32 MM 内核:
Block 82:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 83:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
vaddpd ymm2, ymm10, ymm2
vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
vaddpd ymm3, ymm11, ymm3
vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
vaddpd ymm0, ymm9, ymm0
vaddpd ymm1, ymm12, ymm1
vaddpd ymm4, ymm10, ymm4
vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
vmulpd ymm8, ymm8, ymmword ptr [rax]
vaddpd ymm5, ymm11, ymm5
add rax, 0x5dc0
vaddpd ymm6, ymm10, ymm6
vaddpd ymm7, ymm8, ymm7
cmp r13, 0x20
jnz 0x140004530 <Block 83>
Block 84:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400044d0 <Block 82>
AVX2/FMA 32x32 MM 内核:
Block 80:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 81:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
add rax, 0x5dc0
cmp r13, 0x20
jnz 0x140004450
Block 82:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400043f0 <Block 80>
【问题讨论】:
标签: assembly matrix-multiplication simd avx micro-optimization