【问题标题】:GEMM kernel implemented using AVX2 is faster than AVX2/FMA on a Zen 2 CPU使用 AVX2 实现的 GEMM 内核比 Zen 2 CPU 上的 AVX2/FMA 更快
【发布时间】:2022-01-17 06:49:11
【问题描述】:

我已尝试加快玩具 GEMM 的实施。我处理需要优化的 MM 内核的 32x32 双精度块。我可以访问 AVX2 和 FMA。

我在下面定义了两个代码(在 ASM 中,我为格式的粗糙表示歉意),一个是使用 AVX2 功能,另一个是使用 FMA。

在不涉及微观基准的情况下,我想尝试了解(理论上)为什么 AVX2 实现比 FMA 版本快 1.11 倍。以及如何改进这两个版本。

下面的代码适用于 3000x3000 MM 的双精度数,内核是使用经典的、朴素的 MM 实现的,具有可互换的最深循环。我正在使用 Ryzen 3700x/Zen 2 作为开发 CPU。

我没有尝试过积极展开,因为担心 CPU 可能会耗尽物理寄存器。

AVX2 32x32 MM 内核:

Block 82:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 83:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
    vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
    vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
    vaddpd ymm2, ymm10, ymm2    
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
    vaddpd ymm3, ymm11, ymm3    
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
    vaddpd ymm0, ymm9, ymm0   
    vaddpd ymm1, ymm12, ymm1
    vaddpd ymm4, ymm10, ymm4
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
    vmulpd ymm8, ymm8, ymmword ptr [rax]       
    vaddpd ymm5, ymm11, ymm5    
    add rax, 0x5dc0 
    vaddpd ymm6, ymm10, ymm6
    vaddpd ymm7, ymm8, ymm7 
    cmp r13, 0x20
    jnz 0x140004530 <Block 83>
Block 84:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400044d0 <Block 82>

AVX2/FMA 32x32 MM 内核:

Block 80:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 81:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
    vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
    vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
    vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
    vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
    vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
    vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
    vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
    add rax, 0x5dc0 
    cmp r13, 0x20   
    jnz 0x140004450
Block 82:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400043f0 <Block 80>

【问题讨论】:

    标签: assembly matrix-multiplication simd avx micro-optimization


    【解决方案1】:

    Zen2 对于 vaddpd 有 3 个周期延迟,对于 vfma...pd 有 5 个周期延迟。 (https://uops.info/)。

    您的具有 8 个累加器的代码具有足够的 ILP,您预计每个时钟接近两个 FMA,大约每 5 个时钟 8 个(如果没有其他瓶颈),这比 10/5 理论最大值略少。

    vaddpdvmulpd 实际上分别运行在 Zen2 上的 不同 端口(与 Intel 不同)、端口 FP2/3 和 FP0/1,因此理论上可以支持 2/clock @987654327 @ vmulpd。由于循环携带依赖的延迟更短,如果调度不让一个 dep 链落后,8 个累加器足以隐藏vaddpd 延迟。 (但至少乘法不会从中窃取循环。)

    Zen2 的前端是 5 条指令宽(如果有任何多 uop 指令,则为 6 uop),它可以将内存源指令解码为单个 uop。因此,它很可能会使用非 FMA 版本进行每次乘法和加法的 2/时钟。

    如果您可以展开 10 或 12,这可能会隐藏足够的 FMA 延迟并使其与非 FMA 版本相同,但功耗更低,并且对在其他逻辑内核上运行的代码更友好。 (10 = 5 x 2 只是几乎足够了,这意味着任何调度缺陷都会在关键路径上的 dep 链上失去进度。有关英特尔的一些测试,请参阅 Why does mulss take only 3 cycles on Haswell, different from Agner's instruction tables? (Unrolling FP loops with multiple accumulators)。)

    (相比之下,英特尔 Skylake 在相同端口上运行 vaddpd/vmulpd,延迟与 vfma...pd 相同,延迟均为 4c,吞吐量为 0.5c。)

    我没有仔细查看您的代码,但是 10 个 YMM 向量可能是在接触两对缓存线与接触总共 5 行之间的权衡,如果空间预取器尝试完成对齐的对,这可能会更糟.或者可能没问题。 12个YMM向量是三对,应该没问题。

    根据矩阵大小,无序执行可能能够在外循环的单独迭代之间重叠内循环 dep 链,特别是如果循环退出条件可以更快地执行并解决错误预测(如果有的话)而FP工作仍在进行中。这是一个优势,即同一工作的总微指令更少,有利于 FMA。

    【讨论】:

    • 从 Skylake 开始,英特尔处理器可以向 p0 或 p1 发出 AVX2 VADDPD、VMULPD 和 VFMA 指令——所有这些都具有 4 个周期的延迟和 2 个指令/周期的峰值吞吐量。只有 Haswell/Broadwell 一代有奇怪的限制,即 p1 后面的 FMA 单元无法接受 ADD 操作。 (可以将这些转换为 FMA 并提高吞吐量。)
    • @JohnDMcCalpin 在 Zen 2 上,当代码像 OP 的示例一样进行 50% 的加法和 50% 的乘法时,总限制为 4/clock。 CPU 可以调度多达 6 条指令/时钟,因此 OP 的数学运算的 4/时钟仍然适合该特定瓶颈。对于 Intel,无论是加法、乘法还是 fma 指令,它都是 2/clock。在 Intel 上,FMA 可能是这些周期的最佳用途。
    • 我正在使用超线程,这会因为 ROB 分区而影响结果吗?
    • @EtienneM:静态分区意味着更小的乱序窗口,但并没有真正改变分析。 ROB 的一半应该仍然足以覆盖任何独立的工作。更重要的是两个线程将在相同的端口上相互竞争,所以如果两个线程都使用 8x 展开循环,你实际上有 16 个 dep 链,这应该足以让 10 个 FMA (5 / 0.5) 保持运行。
    猜你喜欢
    • 2016-06-17
    • 2019-10-02
    • 1970-01-01
    • 2018-12-28
    • 2014-06-23
    • 2019-01-11
    • 1970-01-01
    • 1970-01-01
    • 2018-04-24
    相关资源
    最近更新 更多