【发布时间】:2021-12-05 10:57:51
【问题描述】:
当我想取 2、2D numpy 矩阵的点积时,它会按预期工作
>>> a = np.random.rand(20,10)
>>> b = np.random.rand(10,3)
>>> dotP = np.dot(a,b)
>>> np.shape(dotP)
(20, 3)
在我的用例中,我想做一个类似的操作,但使用更高维度的数组。在下面的示例中,a(20 和 10)的前 2 个维度等效于前面示例中矩阵的 dims。最后 3 个维度在此处介绍(6,5,4)。类似地,在b 矩阵中,10 和 3 来自前面的示例,并引入了 (6,5,4)。当我在以下示例中使用点操作时,我得到以下输出。
>>> a = np.random.rand(20,10,6,5,4)
>>> b = np.random.rand(10,6,5,4,3)
>>> dotP = np.dot(a,b)
>>> np.shape(dotP)
(20, 10, 6, 5, 10, 6, 5, 3)
我想要实现的是如下乘法运算:
>>> dotP = np.dot(a,b)
>>> np.shape(dotP)
(20, 3, 6, 5, 4)
一种可能的解决方案是在 for 循环中进行广播,但我不确定这是否是最好的方法:
dotP = []
for x in range(np.shape(b)[-1]):
dotP.append(np.sum(a * b[:, :, :, :, x], axis=1).reshape((20,1,6,5,4)))
dotP = tuple(dotP)
dotP = np.hstack(dotP)
【问题讨论】:
-
查看
np.matmul。仔细阅读文档。