【Question title】: Multiply numpy ndarray with 1d array along a given axis
【Posted】: 2015-07-13 22:32:11
【Question】:

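I seem to be getting lost in something that is probably silly. I have an n-dimensional numpy array, and I want to multiply it by a vector (1d array) along some dimension (which can change!). For example, say I want to multiply a 2d array by a 1d array along axis 0 of the first array; I can do this: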

a=np.arange(20).reshape((5,4))
b=np.ones(5)
c=a*b[:,np.newaxis]

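Easy enough, but I would like to extend this idea to n dimensions (for a, while b is always 1d) and to any axis. In other words, I would like to know how to generate a slice with np.newaxis in the right places. Say a is 3d and I want to multiply along axis=1; I would like to generate the slice that correctly gives: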

c=a*b[np.newaxis,:,np.newaxis]

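That is, given the number of dimensions of a (say 3) and the axis along which I want to multiply (say axis=1), how do I generate and pass the slice: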

np.newaxis,:,np.newaxis

Thanks.

【Comments】:

  • I have data in an ndarray, and I want to multiply the data by a filter along some axis.

Tags: python arrays numpy slice


【Solution 1】:

Solution code:

import numpy as np

# Given axis along which elementwise multiplication with broadcasting 
# is to be performed
given_axis = 1

# Create an array which would be used to reshape 1D array, b to have 
# singleton dimensions except for the given axis where we would put -1 
# signifying to use the entire length of elements along that axis  
dim_array = np.ones((1,a.ndim),int).ravel()
dim_array[given_axis] = -1

# Reshape b with dim_array and perform elementwise multiplication with 
# broadcasting along the singleton dimensions for the final output
b_reshaped = b.reshape(dim_array)
mult_out = a*b_reshaped

Sample run demonstrating the steps:

In [149]: import numpy as np

In [150]: a = np.random.randint(0,9,(4,2,3))

In [151]: b = np.random.randint(0,9,(2,1)).ravel()

In [152]: whos
Variable   Type       Data/Info
-------------------------------
a          ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
b          ndarray    2: 2 elems, type `int32`, 8 bytes

In [153]: given_axis = 1

Now we want to perform the elementwise multiplication along given_axis = 1. Let's create dim_array:

In [154]: dim_array = np.ones((1,a.ndim),int).ravel()
     ...: dim_array[given_axis] = -1
     ...: 

In [155]: dim_array
Out[155]: array([ 1, -1,  1])

Finally, reshape b and perform the elementwise multiplication:

In [156]: b_reshaped = b.reshape(dim_array)
     ...: mult_out = a*b_reshaped
     ...: 

Check the whos output again and pay special attention to b_reshaped and mult_out:

In [157]: whos
Variable     Type       Data/Info
---------------------------------
a            ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
b            ndarray    2: 2 elems, type `int32`, 8 bytes
b_reshaped   ndarray    1x2x1: 2 elems, type `int32`, 8 bytes
dim_array    ndarray    3: 3 elems, type `int32`, 12 bytes
given_axis   int        1
mult_out     ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
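The steps above can be packaged into a small function. This is a minimal sketch, not part of the original answer: the name multiply_along_axis_reshape is hypothetical, and a plain Python list stands in for the dim_array of ones, which gives the same reshape target.

```python
import numpy as np

def multiply_along_axis_reshape(a, b, axis):
    """Hypothetical helper: multiply `a` by the 1-D array `b` along `axis`."""
    a = np.asarray(a)
    b = np.asarray(b)
    # Shape of all ones with -1 at the target axis, e.g. [1, -1, 1] for axis=1 in 3-D
    new_shape = [1] * a.ndim
    new_shape[axis] = -1
    # b is only reshaped, never tiled to the size of a; broadcasting does the rest
    return a * b.reshape(new_shape)

a = np.arange(24).reshape((4, 2, 3))
b = np.array([10, 100])
out = multiply_along_axis_reshape(a, b, axis=1)
print(out.shape)  # (4, 2, 3)
print(np.array_equal(out, a * b[np.newaxis, :, np.newaxis]))  # True
```

Because b is only reshaped to 1x2x1, no array of the size of a is created for b, which matters when a is very large.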

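【Comments】:

  • OK, my bad, I did not mention this: I cannot generate a copy of b sized to match a, because a can be very, very large.
  • Hey, no, sorry, this actually is a solution; I misunderstood it. Great, thanks!
  • @AJC That's perfectly fine! Removing my earlier comment.
  • I am always amazed at how powerful numpy is. Cool! Thanks again.
  • @AJC I have been discovering the same over the past month! :)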
【Solution 2】:

You can build a slice object and select the desired dimension in it:

import numpy as np

a = np.arange(18).reshape((3,2,3))
b = np.array([1,3])

ss = [None for i in range(a.ndim)]
ss[1] = slice(None)    # set the dimension along which to broadcast

print(ss)  #  [None, slice(None, None, None), None]

# Index with a tuple: recent NumPy versions no longer treat a list as a multi-axis index
c = a*b[tuple(ss)]
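The same idea can be wrapped in a helper that answers the original question directly, namely how to generate the index for any ndim and axis. This is a sketch under the assumption that a tuple index is wanted; the name broadcast_index is hypothetical.

```python
import numpy as np

def broadcast_index(ndim, axis):
    """Hypothetical helper: build an index such as (None, slice(None), None)."""
    index = [np.newaxis] * ndim   # np.newaxis is an alias for None
    index[axis] = slice(None)     # equivalent to ':' at the chosen axis
    return tuple(index)           # a tuple is read as one entry per axis

a = np.arange(18).reshape((3, 2, 3))
b = np.array([1, 3])
idx = broadcast_index(a.ndim, axis=1)
print(idx)      # (None, slice(None, None, None), None)
c = a * b[idx]
print(c.shape)  # (3, 2, 3)
```

Here b[idx] is the same as writing b[np.newaxis, :, np.newaxis] by hand.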

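【Comments】:

    【Solution 3】:

    I had a similar need while doing some numerical computation.

    Assume we have two arrays (A and B) and a user-specified 'axis'. A is a multi-dimensional array. B is a 1-D array.

    The basic idea is to expand B so that A and B have the same shape. Here is the solution code: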

    import numpy as np
    try:
        from numpy.exceptions import AxisError  # NumPy 1.25 and later
    except ImportError:
        from numpy import AxisError  # older NumPy releases
    
    def multiply_along_axis(A, B, axis):
        A = np.array(A)
        B = np.array(B)
        # shape check
        if axis >= A.ndim:
            raise AxisError(axis, A.ndim)
        if A.shape[axis] != B.size:
            raise ValueError("'A' and 'B' must have the same length along the given axis")
        # Expand the 'B' according to 'axis':
        # 1. Swap the given axis with axis=0 (just need the swapped 'shape' tuple here)
        swapped_shape = A.swapaxes(0, axis).shape
        # 2. Repeat:
        # loop through the number of A's dimensions, at each step:
        # a) repeat 'B':
        #    The number of repetition = the length of 'A' along the 
        #    current looping step; 
        #    The axis along which the values are repeated. This is always axis=0,
        #    because 'B' initially has just 1 dimension
        # b) reshape 'B':
        #    'B' is then reshaped as the shape of 'A'. But this 'shape' only 
        #     contains the dimensions that have been counted by the loop
        for dim_step in range(A.ndim-1):
            B = B.repeat(swapped_shape[dim_step+1], axis=0)\
                 .reshape(swapped_shape[:dim_step+2])
        # 3. Swap the axis back to ensure the returned 'B' has exactly the 
        # same shape of 'A'
        B = B.swapaxes(0, axis)
        return A * B
    

    Here is an example:

    In [33]: A = np.random.rand(3,5)*10; A = A.astype(int); A
    Out[33]: 
    array([[7, 1, 4, 3, 1],
           [1, 8, 8, 2, 4],
           [7, 4, 8, 0, 2]])
    
    In [34]: B = np.linspace(3,7,5); B
    Out[34]: array([3., 4., 5., 6., 7.])
    
    In [35]: multiply_along_axis(A, B, axis=1)
    Out[35]: 
    array([[21.,  4., 20., 18.,  7.],
           [ 3., 32., 40., 12., 28.],
           [21., 16., 40.,  0., 14.]])
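To see what the repeat-and-reshape loop in this solution does, it can help to run two of its steps by hand on a tiny array. The following is only an illustrative sketch with made-up sizes (3, 2, 4); it checks that the expanded array equals what broadcasting would produce, with the difference that repeat allocates a full-size copy.

```python
import numpy as np

b = np.array([1, 2, 3])

# Step 1: repeat every element twice, then fold into shape (3, 2)
step1 = b.repeat(2, axis=0).reshape((3, 2))
print(step1.tolist())  # [[1, 1], [2, 2], [3, 3]]

# Step 2: repeat every row four times, then fold into shape (3, 2, 4)
step2 = step1.repeat(4, axis=0).reshape((3, 2, 4))

# The result matches b broadcast along axis 0 of a (3, 2, 4) array
expected = np.broadcast_to(b[:, np.newaxis, np.newaxis], (3, 2, 4))
print(np.array_equal(step2, expected))  # True
print(step2.nbytes // b.nbytes)         # 8
```

The last line shows that the expanded array holds 8 times as many bytes as b here, so this approach can be costly for large inputs.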
    

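    【Comments】:

      【Solution 4】:

      Avoid copying data and wasting resources!

      Using casting and views, rather than actually copying the data N times into a new array with the appropriate shape (as some of the existing answers do), is much more memory efficient. Here is such a method (based on @ShuxuanXU's code):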

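      import numpy as np
      try:
          from numpy.exceptions import AxisError  # NumPy 1.25 and later
      except ImportError:
          from numpy import AxisError  # older NumPy releases

      def mult_along_axis(A, B, axis):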
      
          # ensure we're working with Numpy arrays
          A = np.array(A)
          B = np.array(B)
      
          # shape check
          if axis >= A.ndim:
              raise AxisError(axis, A.ndim)
          if A.shape[axis] != B.size:
              raise ValueError(
                  "Length of 'A' along the given axis must be the same as B.size"
                  )
      
          # np.broadcast_to puts the new axis as the last axis, so 
          # we swap the given axis with the last one, to determine the
          # corresponding array shape. np.swapaxes only returns a view
          # of the supplied array, so no data is copied unnecessarily.
          shape = np.swapaxes(A, A.ndim-1, axis).shape
      
          # Broadcast to an array with the shape as above. Again, 
          # no data is copied, we only get a new look at the existing data.
          B_brc = np.broadcast_to(B, shape)
      
          # Swap back the axes. As before, this only changes our "point of view".
          B_brc = np.swapaxes(B_brc, A.ndim-1, axis)
      
          return A * B_brc
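The claim that no data is copied can be checked by inspecting strides and memory sharing. This is a small sketch with an assumed float64 array of three elements and an assumed target shape of (4, 5, 3).

```python
import numpy as np

B = np.arange(3, dtype=np.float64)
shape = (4, 5, 3)                     # last axis matches B.size

B_brc = np.broadcast_to(B, shape)     # read-only view, nothing is copied
print(B_brc.strides)                  # (0, 0, 8)
print(np.shares_memory(B, B_brc))     # True
print(B_brc.flags.writeable)          # False

B_swapped = np.swapaxes(B_brc, 2, 0)  # still a view, only the axis order changes
print(B_swapped.shape)                # (3, 5, 4)
print(B_swapped.strides)              # (8, 0, 0)
```

A stride of 0 means that moving along that axis revisits the same memory, which is how the broadcast view repeats B without storing it more than once.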
      

      【Comments】:

        【Solution 5】:

        You can also use a simple matrix trick:

        c = np.matmul(a, np.diag(b))
        

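        Basically this just performs a matrix multiplication between a and a matrix whose diagonal holds the elements of b, which scales a along its last axis. It may not be efficient, but it is a nice one-liner.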
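Since the matrix product acts on the last axis, one possible way to use this trick for an arbitrary axis is to move that axis to the end first and move it back afterwards. This is a sketch only: the function name mult_along_axis_matmul is hypothetical, and a is assumed to have at least two dimensions.

```python
import numpy as np

def mult_along_axis_matmul(a, b, axis):
    """Hypothetical helper: apply the diag trick along an arbitrary axis."""
    moved = np.moveaxis(a, axis, -1)       # view with `axis` moved to the end
    scaled = np.matmul(moved, np.diag(b))  # scales the last axis by b
    return np.moveaxis(scaled, -1, axis)   # restore the original axis order

a = np.arange(24).reshape((4, 2, 3))
b = np.array([10, 100])
c = mult_along_axis_matmul(a, b, axis=1)
print(c.shape)  # (4, 2, 3)
print(np.array_equal(c, a * b[np.newaxis, :, np.newaxis]))  # True
```

Note that np.diag(b) builds a dense square matrix of side b.size, so this variant does far more arithmetic than plain broadcasting.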

        【Comments】:
