【问题标题】:Multiplication by broadcasting rows of 2D array to each slice of 3D array using np.einsum通过使用 np.einsum 将 2D 数组的行广播到 3D 数组的每个切片来进行乘法
【发布时间】:2020-08-16 11:06:57
【问题描述】:

我有数组AB

>>> import numpy as np

>>> A = np.ones((3,3,2))

>>> B = np.array([
    [0,0],
    [1,1],
    [2,2],
])

我想将B的每一行乘以A的每一切片,这样B的每一行就在A的每一切片上进行广播,即:

>>> np.array([A_slice*B_row for A_slice, B_row in zip(A, B)])
[[[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[2. 2.]
  [2. 2.]
  [2. 2.]]]

我想要最有效的方法,我相信这可能是使用np.einsum(但是,如果您认为使用另一种方法更快,例如我在下面提到的方法,请告诉我)。

我尝试了以下方法:

>>> np.einsum('ijk,lk->ijk', A, B)
[[[3. 3.]
  [3. 3.]
  [3. 3.]]

 [[3. 3.]
  [3. 3.]
  [3. 3.]]

 [[3. 3.]
  [3. 3.]
  [3. 3.]]]

如您所见,这显然与上面的输出不同。

我能想到的另一个解决方案是:

>>> A*B[:,np.newaxis,:].repeat(3, axis=1)
[[[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[2. 2.]
  [2. 2.]
  [2. 2.]]]

哪个确实给出了正确的输出,但我仍然很想知道如何使用np.einsum 来做到这一点

编辑: Warren Weckesser 在 cmets 中指出,上面的解决方案可以简化为 A*B[:,np.newaxis,:],这是我见过的不使用 np.einsum 的最干净的解决方案。

【问题讨论】:

  • B 是 (4,2),A 是 (3,3,2)。 B 的最后一行应该发生什么?
  • 抱歉,我忘记编辑了。 B 应该是 (3,2),谢谢提醒
  • 您可以通过删除repeat: A*B[:, np.newaxis, :] 来简化您的最后一个解决方案。您使用repeat 完成的工作实际上就是广播的作用。
  • 很好,我完全忽略了这一点。

标签: python arrays numpy broadcast multiplication


【解决方案1】:

numpy.einsum 解决方案:

C = np.einsum('ijk,jk->jik', A, B)

使用省略号:

C = np.einsum('ij...,j...->ji...', A, B)

输出

[[[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[2. 2.]
  [2. 2.]
  [2. 2.]]]

【讨论】:

  • 感谢您的尝试,这确实适用于我说明的情况,但是我尝试将问题“扩展”一点:A = np.ones((4,3,2))B = np.array([[0,0],[1,1],[2,2],[3,3]]) 并出现错误提示 @987654327 @
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2023-04-03
  • 1970-01-01
  • 1970-01-01
  • 2022-10-04
  • 2021-01-21
  • 1970-01-01
  • 2021-02-07
相关资源
最近更新 更多