【问题标题】:Numpy roll in several dimensions多个维度的 Numpy 卷
【发布时间】:2021-04-03 00:26:42
【问题描述】:

我需要将一个 3D 数组移动一个 3D 位移向量来进行算法。 截至目前,我正在使用这种(诚然非常丑陋)的方法:

shiftedArray = np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0)
                                     , shift[1], axis=1),
                             shift[2], axis=2)  

这行得通,但意味着我要掷 3 卷! (根据我的分析,我 58% 的算法时间都花在了这些上面)

来自 Numpy.roll 的文档:

参数:
班次:int

轴:int,可选

参数中没有提到类数组...所以我不能进行多维滚动?

我以为我可以调用这种函数(听起来像是 Numpy 要做的事情):

np.roll(arrayToShift,3DshiftVector,axis=(0,1,2))

也许用我的阵列的扁平版本重塑?但是我该如何计算移位向量?这种转变真的一样吗?

我很惊讶没有找到简单的解决方案,因为我认为这将是一件很常见的事情(好吧,不是 常见的,但是...)

那么我们如何——相对地——通过一个 N 维向量有效地移动一个 ndarray 呢?


注意:这个问题是在 2015 年提出的,当时 numpy 的 roll 方法不支持此功能

【问题讨论】:

    标签: python arrays numpy


    【解决方案1】:

    理论上,使用@Ed Smith 所描述的scipy.ndimage.interpolation.shift 应该可以工作,但由于存在错误(https://github.com/scipy/scipy/issues/1323),它没有给出相当于多次调用np.roll 的结果。


    更新:在 numpy 版本 1.12.0 中,numpy.roll 添加了“多卷”功能。这是一个二维示例,其中第一个轴滚动一个位置,第二个轴滚动三个位置:

    In [7]: x = np.arange(20).reshape(4,5)
    
    In [8]: x
    Out[8]: 
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19]])
    
    In [9]: numpy.roll(x, [1, 3], axis=(0, 1))
    Out[9]: 
    array([[17, 18, 19, 15, 16],
           [ 2,  3,  4,  0,  1],
           [ 7,  8,  9,  5,  6],
           [12, 13, 14, 10, 11]])
    

    这会使下面的代码过时。我会把它留给后代。


    下面的代码定义了一个我称之为multiroll 的函数,它可以满足您的需求。这是一个将其应用于形状为 (500, 500, 500) 的数组的示例:

    In [64]: x = np.random.randn(500, 500, 500)
    
    In [65]: shift = [10, 15, 20]
    

    使用多次调用np.roll 来生成预期结果:

    In [66]: yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2)
    

    使用multiroll生成移位数组:

    In [67]: ymulti = multiroll(x, shift)
    

    验证我们得到了预期的结果:

    In [68]: np.all(yroll3 == ymulti)
    Out[68]: True
    

    对于这种大小的数组,对np.roll 进行三次调用几乎比对multiroll 的调用慢三倍:

    In [69]: %timeit yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2)
    1 loops, best of 3: 1.34 s per loop
    
    In [70]: %timeit ymulti = multiroll(x, shift)
    1 loops, best of 3: 474 ms per loop
    

    这是multiroll的定义:

    from itertools import product
    import numpy as np
    
    
    def multiroll(x, shift, axis=None):
        """Roll an array along each axis.
    
        Parameters
        ----------
        x : array_like
            Array to be rolled.
        shift : sequence of int
            Number of indices by which to shift each axis.
        axis : sequence of int, optional
            The axes to be rolled.  If not given, all axes is assumed, and
            len(shift) must equal the number of dimensions of x.
    
        Returns
        -------
        y : numpy array, with the same type and size as x
            The rolled array.
    
        Notes
        -----
        The length of x along each axis must be positive.  The function
        does not handle arrays that have axes with length 0.
    
        See Also
        --------
        numpy.roll
    
        Example
        -------
        Here's a two-dimensional array:
    
        >>> x = np.arange(20).reshape(4,5)
        >>> x 
        array([[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14],
               [15, 16, 17, 18, 19]])
    
        Roll the first axis one step and the second axis three steps:
    
        >>> multiroll(x, [1, 3])
        array([[17, 18, 19, 15, 16],
               [ 2,  3,  4,  0,  1],
               [ 7,  8,  9,  5,  6],
               [12, 13, 14, 10, 11]])
    
        That's equivalent to:
    
        >>> np.roll(np.roll(x, 1, axis=0), 3, axis=1)
        array([[17, 18, 19, 15, 16],
               [ 2,  3,  4,  0,  1],
               [ 7,  8,  9,  5,  6],
               [12, 13, 14, 10, 11]])
    
        Not all the axes must be rolled.  The following uses
        the `axis` argument to roll just the second axis:
    
        >>> multiroll(x, [2], axis=[1])
        array([[ 3,  4,  0,  1,  2],
               [ 8,  9,  5,  6,  7],
               [13, 14, 10, 11, 12],
               [18, 19, 15, 16, 17]])
    
        which is equivalent to:
    
        >>> np.roll(x, 2, axis=1)
        array([[ 3,  4,  0,  1,  2],
               [ 8,  9,  5,  6,  7],
               [13, 14, 10, 11, 12],
               [18, 19, 15, 16, 17]])
    
        """
        x = np.asarray(x)
        if axis is None:
            if len(shift) != x.ndim:
                raise ValueError("The array has %d axes, but len(shift) is only "
                                 "%d. When 'axis' is not given, a shift must be "
                                 "provided for all axes." % (x.ndim, len(shift)))
            axis = range(x.ndim)
        else:
            # axis does not have to contain all the axes.  Here we append the
            # missing axes to axis, and for each missing axis, append 0 to shift.
            missing_axes = set(range(x.ndim)) - set(axis)
            num_missing = len(missing_axes)
            axis = tuple(axis) + tuple(missing_axes)
            shift = tuple(shift) + (0,)*num_missing
    
        # Use mod to convert all shifts to be values between 0 and the length
        # of the corresponding axis.
        shift = [s % x.shape[ax] for s, ax in zip(shift, axis)]
    
        # Reorder the values in shift to correspond to axes 0, 1, ..., x.ndim-1.
        shift = np.take(shift, np.argsort(axis))
    
        # Create the output array, and copy the shifted blocks from x to y.
        y = np.empty_like(x)
        src_slices = [(slice(n-shft, n), slice(0, n-shft))
                      for shft, n in zip(shift, x.shape)]
        dst_slices = [(slice(0, shft), slice(shft, n))
                      for shft, n in zip(shift, x.shape)]
        src_blks = product(*src_slices)
        dst_blks = product(*dst_slices)
        for src_blk, dst_blk in zip(src_blks, dst_blks):
            y[dst_blk] = x[src_blk]
    
        return y
    

    【讨论】:

      【解决方案2】:

      我认为scipy.ndimage.interpolation.shift 会做你想做的事,来自docs

      shift : 浮点数或序列,可选

      沿轴的偏移。如果是浮点数,则每个轴的移位都是相同的。如果是一个序列,则 shift 应该为每个轴包含一个值。

      这意味着您可以执行以下操作,

      from scipy.ndimage.interpolation import shift
      import numpy as np
      
      arrayToShift = np.reshape([i for i in range(27)],(3,3,3))
      
      print('Before shift')
      print(arrayToShift)
      
      shiftVector = (1,2,3)
      shiftedarray = shift(arrayToShift,shift=shiftVector,mode='wrap')
      
      print('After shift')
      print(shiftedarray)
      

      哪个产量,

      Before shift
      [[[ 0  1  2]
        [ 3  4  5]
        [ 6  7  8]]
      
       [[ 9 10 11]
        [12 13 14]
        [15 16 17]]
      
       [[18 19 20]
        [21 22 23]
        [24 25 26]]]
      After shift
      [[[16 17 16]
        [13 14 13]
        [10 11 10]]
      
       [[ 7  8  7]
        [ 4  5  4]
        [ 1  2  1]]
      
       [[16 17 16]
        [13 14 13]
        [10 11 10]]]
      

      【讨论】:

      • 不错!我想我对 numpy 的滚动很着迷,并在谷歌上搜索了一下。只有当我写了这个 SO 问题时,我才将其声明为“转变”,这会直接导致我使用这个功能!虽然有点编辑,因为它使用样条进行非轮换(但我使用纯整数),我必须指定 order=0 (否则这个方法比我的旧方法花费了 100 倍!=p)
      • 太好了,很高兴它也很有效。我最近在回答另一个问题(stackoverflow.com/questions/30399534/…)时才发现 scipy shift。 Fortran 具有内在的移位功能,所以我猜 scipy 使用它,这就是它可以比 numpy 快得多的原因。猜猜order=0 避免了样条...
      • scipy 移位使用样条插值。所以它正在计算新值,而不仅仅是移动现有值。
      • 这不等同于多次使用np.roll。例如,请注意您的输入包含0,但您的输出不包含。
      • @WarrenWeckesser 结果的差异是由于不适当的mode='wrap'。需要mode='grid-wrap' 来复制np.roll 的行为
      【解决方案3】:

      np.roll 具有多个维度。做吧

      np.roll(arrayToShift, (shift[0], shift[1], shift[2]), axis=(0,1,2))
      

      它不是很聪明,所以你必须指定轴

      【讨论】:

        【解决方案4】:

        我相信roll 很慢,因为滚动数组不能像切片或重塑操作那样表示为原始数据的视图。所以每次都复制数据。背景见:https://scipy-lectures.github.io/advanced/advanced_numpy/#life-of-ndarray

        首先填充数组(使用 'wrap' 模式)然后使用填充数组上的切片来获取 shiftedArray:http://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html

        【讨论】:

        • 是的,我应该说我一开始就假设数组的(单个)副本可以做,但是连续滚动三次确实意味着 2 个冗余副本,我想避免。
        • 是的,此方法将复制一次,然后切片将只是该数据的视图。
        • np.roll 使用arange 构造indexes 元组,然后是take。所以它正在做较慢的高级索引。理论上,您可以构建自己的 3d 索引元组,并且只应用一次。但这可能是很多工作。
        【解决方案5】:

        wrap模式下的take可以使用,我认为它不会改变内存中的数组。

        这是一个使用@EdSmith 输入的实现:

        arrayToShift = np.reshape([i for i in range(27)],(3,3,3))
        shiftVector = np.array((1,2,3))
        ind = 3-shiftVector    
        np.take(np.take(np.take(arrayToShift,range(ind[0],ind[0]+3),axis=0,mode='wrap'),range(ind[1],ind[1]+3),axis=1,mode='wrap'),range(ind[2],ind[2]+3),axis=2,mode='wrap')
        

        与 OP 相同:

        np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0) , shift[1], axis=1),shift[2], axis=2) 
        

        给予:

        array([[[21, 22, 23],
                [24, 25, 26],
                [18, 19, 20]],
        
               [[ 3,  4,  5],
                [ 6,  7,  8],
                [ 0,  1,  2]],
        
               [[12, 13, 14],
                [15, 16, 17],
                [ 9, 10, 11]]])
        

        【讨论】:

          猜你喜欢
          • 2013-12-10
          • 1970-01-01
          • 1970-01-01
          • 2016-04-05
          • 2020-07-24
          • 2021-11-12
          • 1970-01-01
          • 2012-02-03
          • 2017-12-23
          相关资源
          最近更新 更多