【问题标题】:How to find the first nonzero in an array efficiently?如何有效地找到数组中的第一个非零值?
【发布时间】:2021-12-27 18:03:31
【问题描述】:

假设我们要快速找到数组中第一个非零元素的索引,效果如下

fn leading_zeros(arr: &[u32]) -> Option<usize> {
    arr.iter().position(|&x| x != 0)
}

但是,rustc 将其编译为 seen here 的一项检查。 可以通过使用u128 类型检查单词 4 x 4 来加快速度,如下所示。这使我的机器上的速度提高了大约 3 倍。

fn leading_zeros_wide(arr: &[u32]) -> Option<usize> {
    let (beg, mid, _) = unsafe { arr.align_to::<u128>() };

    beg.iter().position(|&x| x != 0).or_else(|| {
        let left = beg.len() + 4 * mid.iter().position(|&x| x != 0).unwrap_or(mid.len());
        arr[left..].iter().position(|&x| x != 0).map(|p| p + left)
    })
}

有没有办法让它更快?


这是我用来确定 3 倍加速的基准:

#![feature(test)]
extern crate test;

fn v() -> Box<[u32]> {
    std::iter::repeat(0).take(1000).collect()
}

// Assume `leading_zeros` and `leading_zeros_wide` are defined here.

#[bench]
fn bench_leading_zeros(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros(&v[3..]))
}

#[bench]
fn bench_leading_zeros_wide(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros_wide(&v[3..]))
}

【问题讨论】:

  • @JohnKugelman 我没有使用 end 参数,因为切片 arr[left..] 包含该部分
  • @JohnKugelman mm 你的解释我明白为什么 end 被忽略了,我认为代码应该得到一些注释或更好的变量命名。现在看我没问题。也就是说,既然这个问题说这更快,我认为如果不是要求,在问题中测试它的基准代码将是一个加分。
  • 我认为docs.rs/memx/latest/memx/fn.memnechr.html应该更快更可靠
  • 谢谢大家!可悲的是,memx crate 目前似乎对memnechr 有一个错误(至少对于 0.1.18)
  • 我看到即使将 SIMD 功能指定为编译器选项,您的优化版本仍然不是 SIMD:rust.godbolt.org/z/8scnKToq8 这意味着它可以进一步优化。显然有一种方法可以直接使用 CPU 内部函数:x86arm。对不起,我不会提供这个解决方案,我不知道 Rust(我通过[simd] 标签看到了这个问题)

标签: rust simd


【解决方案1】:

64 位:https://rust.godbolt.org/z/rsxh8P8Er

32 位:https://rust.godbolt.org/z/3P3ejsnh1

我对 Rust 和 Assembly 有一点经验,但我添加了一些测试。

#[cfg(target_feature = "avx2")]
pub mod avx2 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    fn first_nonzero_tiny(arr: &[u32]) -> Option<usize> {
        arr.iter().position(|&x| x != 0)
    }

    fn find_u32_zeros_8elems(arr: &[u32], offset: isize) -> i32 {
        unsafe {
            let ymm0 = _mm256_setzero_si256();
            let mut ymm1 = _mm256_loadu_si256(arr.as_ptr().offset(offset) as *const __m256i);
            ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
            let ymm2 = _mm256_castsi256_ps(ymm1);
            _mm256_movemask_ps(ymm2)
        }
    }

    pub fn first_nonzero(arr: &[u32]) -> Option<usize> {
        let size = arr.len();
        if size < 8 {
            return first_nonzero_tiny(arr);
        }

        let mut i: usize = 0;
        let simd_size = size / 8 * 8;
        while i < simd_size {
            let mask: i32 = find_u32_zeros_8elems(&arr, i as isize);
            //println!("mask = {}", mask);
            if mask != 255 {
                return Some((mask.trailing_ones() as usize) + i);
            }
            i += 8;
            //println!("i = {}", i);
        }

        let last_chunk = size - 8;
        let mask: i32 = find_u32_zeros_8elems(&arr, last_chunk as isize);
        if mask != 255 {
            return Some((mask.trailing_ones() as usize) + last_chunk);
        }

        None
    }
}

use avx2::first_nonzero;

pub fn main() {
    let v = [0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [2];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [1, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [0, 1, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(1));

    let v = [0, 0, 1, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(2));

    let v = [0, 0, 0, 1, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(3));

    let v = [0, 0, 0, 0, 1, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(4));

    let v = [0, 0, 0, 0, 0, 1, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(5));

    let v = [0, 0, 0, 0, 0, 1, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(5));

    let v = [0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(6));

    let v = [0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(7));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(8));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(16));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(15));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 3, 4, 5];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(14));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(17));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(18));
}

【讨论】:

  • 看起来不错。对于任何大小> = 8,也应该可以使用 SIMD 处理尾部,最后一个向量是在数组末尾结束的未对齐负载。 (_mm256_loadu_si256;很惊讶你在循环中使用了需要对齐的load 而没有评论该输入要求)。 (将 SIMD 内容包装在辅助函数中可能会很方便,而不是标量回退)。
  • 如果您使用i += 8 并在将结果转换为__m256i* 之前将其用作指向i32 的指针的偏移量,则处理起来会更方便,就像C 风格的@987654328 @ 而不是 i + (const __m128i*)arr。即使对于您当前的代码,这也可以让您使用结尾的 i 值,而不是 n*8 .. arr.len()。虽然那时你必须做i &lt; n-7;您正在使用 n/=8; 解决向量过冲问题,所以此时它只是为手动 SIMD 进行数组索引的两种同样好的方式。
  • 谢谢@PeterCordes,我编辑了我的答案,现在它使用_mm256_loadu_si256 而不是_mm256_load_si256
  • @PeterCordes,显然target-feature=+avx2,bmi,bmi2 没有启用tzcnt,你必须使用target-feature=+avx2,+bmi(这里似乎不需要bmi2)
  • @IgorZhukov 非常感谢!将在我的机器上跟进测试结果(这需要时间,因为我需要将其移植到 aarch64)
【解决方案2】:

这是一个解决方案,它比基线更快,但可能仍然有很多问题。

以下实现了基线first_nonzero 的 7.5 倍。

/// Finds the position of the first nonzero element in a given slice which
/// contains a nonzero.
///
/// # Safety
///
/// The caller *has* to ensure that the input slice has a nonzero.
unsafe fn first_nonzero_padded(arr: &[u32]) -> usize {
    let (beg, mid, _) = arr.align_to::<u128>();
    beg.iter().position(|&x| x != 0).unwrap_or_else(|| {
        let left = beg.len()
            + 4 * {
                let mut p: *const u128 = mid.as_ptr();
                loop {
                    if *p.offset(0) != 0 { break p.offset(0); }
                    if *p.offset(1) != 0 { break p.offset(1); }
                    if *p.offset(2) != 0 { break p.offset(2); }
                    if *p.offset(3) != 0 { break p.offset(3); }
                    if *p.offset(4) != 0 { break p.offset(4); }
                    if *p.offset(5) != 0 { break p.offset(5); }
                    if *p.offset(6) != 0 { break p.offset(6); }
                    if *p.offset(7) != 0 { break p.offset(7); }
                    p = p.offset(8);
                }.offset_from(mid.as_ptr()) as usize
            };
        if let Some(p) = arr[left..].iter().position(|&x| x != 0) {
            left + p
        } else {
            core::hint::unreachable_unchecked()
        }
    })
}

【讨论】:

  • 有没有办法从这个答案的the first revision 编译(在 Godbolt 上)尝试的 SIMD 版本? use core_simd::u64x2; 等等? godbolt.org/z/E6ozdhdYc 不适用于我的 rustc 每晚。如果速度较慢,很可能您的 mask8x8::from_array([ *p.offset(00) != ZERO,07 ]) 没有编译为单个 SSE4.1 pcmpeqq 或其他任何东西。 IDK 是否会花费大量标量工作将 8 个 2 位比较结果打包到单个 mask8x8 中,或者更糟糕地将这些 2 位结果布尔化为 1 位结果?
  • 但无论如何,将其描述为pcmpeqd / tzcnt 几乎可以肯定是虚假的,所以是的,难怪你从你的答案中删除了它:P 而且我对 16 日早出并不感到惊讶-byte 块好一点;你希望内部循环花费大量的工作来准备循环之后的东西,以解决非零元素所在的位置。例如如果您期望长时间运行零,您甚至可以将多个向量组合在一起,然后稍后单独重新检查它们。 (在缓存行大小的块中工作很好,特别是如果您的数据按 64 对齐)
  • 您当前的代码正在执行两个 64 位块的标量 OR,并在由此设置的 FLAGS 上进行分支。 godbolt.org/z/6fMEvveMb
猜你喜欢
  • 2018-04-26
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多